diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 180c3cbe..31bd8e1b 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -19,4 +19,6 @@ endif() if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE OR MLLM_BUILD_QNN_BACKEND) add_subdirectory(qwen3_qnn_aot) + add_subdirectory(qwen2_qnn_aot) + add_subdirectory(llama_qnn_aot) endif() diff --git a/examples/llama_qnn_aot/CMakeLists.txt b/examples/llama_qnn_aot/CMakeLists.txt new file mode 100644 index 00000000..029d8d1e --- /dev/null +++ b/examples/llama_qnn_aot/CMakeLists.txt @@ -0,0 +1,14 @@ +# AOT targets run on x86 +if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) + add_executable(mllm-llama-aot-c compile.cpp) + target_link_libraries(mllm-llama-aot-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) + target_include_directories(mllm-llama-aot-c PRIVATE ${MLLM_INCLUDE_DIR}) + + add_executable(mllm-llama-aot-c-sha compile_sha.cpp) + target_link_libraries(mllm-llama-aot-c-sha PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) + target_include_directories(mllm-llama-aot-c-sha PRIVATE ${MLLM_INCLUDE_DIR}) +endif() + +add_executable(mllm-llama-aot-runner aot_run.cpp) +target_link_libraries(mllm-llama-aot-runner PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) +target_include_directories(mllm-llama-aot-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/llama_qnn_aot/aot_run.cpp b/examples/llama_qnn_aot/aot_run.cpp new file mode 100644 index 00000000..c1918353 --- /dev/null +++ b/examples/llama_qnn_aot/aot_run.cpp @@ -0,0 +1,68 @@ +#include +#include +#include +#include +#include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp" +#include "configuration_llama3.hpp" +#include "mllm/models/llama/tokenization_tiny_llama.hpp" +#include "mllm/models/qwen3/tokenization_qwen3.hpp" + +using mllm::Argparse; +using namespace mllm::qnn::aot; // NOLINT + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model").help("Model path").def("llama_qnn.mllm"); + auto& tokenizer_path = Argparse::add("-t|--tokenizer").help("Tokenizer path").def("tokenizer.json"); + auto& config_path = Argparse::add("-c|--config").help("Config path").required(true); + auto& ar_len = Argparse::add("--ar_len").help("Autoregressive length (chunk size)").def(128); + auto& seq_len = Argparse::add("--seq_len").help("Input sequence length").def(800); + auto& gen_len = Argparse::add("--gen_len").help("Generate token length").def(32); + + Argparse::parse(argc, argv); + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + mllm::initQnnBackend(model_path.get()); + + auto llama_cfg = mllm::models::llama3::Llama3Config(config_path.get()); + + RunnerConfig config; + config.num_layers = llama_cfg.num_hidden_layers; + config.num_heads = llama_cfg.num_attention_heads; + config.head_dim = llama_cfg.head_dim; + config.vocab_size = llama_cfg.vocab_size; + config.context_len = 1024; + config.ar_len = ar_len.get(); + + // Note: Using Qwen3 tokenizer as a placeholder. + // For production use, you should implement a Llama3Tokenizer or use + // the appropriate tokenizer for your model. 
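+  // As wired up below, the runner constructs mllm's TinyLlamaTokenizer from the
+  // --tokenizer path (default "tokenizer.json"); point it at a tokenizer file
+  // that matches the compiled checkpoint.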
+ auto tokenizer = mllm::models::llama::TinyLlamaTokenizer(tokenizer_path.get()); + + auto input_tensor = tokenizer.convertMessage({{ + .role = "user", + .content = "hello", + }}); + + input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); + + // DBG: + mllm::print(input_tensor["sequence"].shape()); + mllm::print(input_tensor["sequence"]); + + Runner runner(config, &tokenizer); + if (!runner.load()) { + std::cerr << "Failed to load model\n"; + return 1; + } + + runner.generate( + input_tensor["sequence"], gen_len.get(), [](const std::string& token) { std::cout << token << std::flush; }, true); + std::cout << "\n"; + + return 0; +}); diff --git a/examples/llama_qnn_aot/compile.cpp b/examples/llama_qnn_aot/compile.cpp new file mode 100644 index 00000000..3568a2f4 --- /dev/null +++ b/examples/llama_qnn_aot/compile.cpp @@ -0,0 +1,159 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +#include "modeling_llama_qnn_aot.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); + auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); + auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + + Argparse::parse(argc, argv); + + int N = 32; + int CL = 1024; + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + if (!qnn_aot_cfg_files.isSet()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided"); + Argparse::printHelp(); + return -1; + } + + auto model_cfg = mllm::models::llama3::Llama3Config(model_cfg_path.get()); + auto model = mllm::models::llama3::LlamaForCausalLM(model_cfg); + auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); + // Add params for causal mask + { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + } + model.load(params); + + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. 
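+  // Two graphs are traced and lowered into the same QNN context: this chunked
+  // prefill graph (N = 32 tokens per step) and the single-token decode graph
+  // below (N = 1). Both share CL = 1024, so each graph's KV-cache inputs hold
+  // the remaining CL - N past positions.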
+ + { + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + auto ir = model.trace(trace_inputs, {}); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("llama_qnn_aot_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. 
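+  // Decode graph: same trace with N = 1, so the KV-cache inputs cover the
+  // remaining 1023 past positions of the 1024-token context.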
+ { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + auto ir = model.trace(trace_inputs, {}); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("llama_qnn_aot_1.mir", [&]() { mllm::print(ir["model"]); }); + } + + qnn_aot_env.saveContext("context.0", "llama-lpbq.bin"); +}); diff --git a/examples/llama_qnn_aot/compile_sha.cpp b/examples/llama_qnn_aot/compile_sha.cpp new file mode 100644 index 00000000..bd938b7a --- /dev/null +++ b/examples/llama_qnn_aot/compile_sha.cpp @@ -0,0 +1,196 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. 
Enables better memory locality per head +// +// Usage: +// ./compile_sha -m /path/to/model.mllm -c /path/to/config.json -aot_cfg /path/to/qnn_aot_cfg.json + +#include +#include +#include +#include +#include +#include + +#include "modeling_llama_qnn_aot_sha.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); + auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); + auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + + Argparse::parse(argc, argv); + + int N = 32; + int CL = 1024; + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + if (!qnn_aot_cfg_files.isSet()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided"); + Argparse::printHelp(); + return -1; + } + + auto model_cfg = mllm::models::llama3::Llama3Config(model_cfg_path.get()); + + // Load original parameters + auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); + + // ============================================================================ + // Key Step: Prepare SHA parameters by slicing MHA weights + // ============================================================================ + // This is the critical step that transforms MHA weights into SHA weights. + // For each Q/K/V projection, we slice the weight matrix into per-head pieces. + // + // Original: q_proj.weight [num_heads * head_dim, hidden_size, 1, 1] + // SHA: q_proj.{h}.weight [head_dim, hidden_size, 1, 1] for each head h + // + mllm::print("Preparing SHA parameters (slicing MHA weights)..."); + mllm::models::llama3::sha::prepareParametersForSHA(params, model_cfg); + mllm::print("SHA parameters prepared."); + + // Create SHA model + auto model = mllm::models::llama3::sha::LlamaForCausalLM_SHA(model_cfg); + + // Add params for causal mask + { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + } + model.load(params); + + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. 
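+  // The SHA traces below consume the per-head weights produced by
+  // prepareParametersForSHA above. Conceptually, that slicing is assumed to look
+  // roughly like the sketch here ("prefix" stands for a per-layer name such as
+  // "model.layers.0.self_attn" and is used purely for illustration):
+  //
+  //   auto w = params->pull(prefix + ".q_proj.weight");  // [H * D, hidden, 1, 1]
+  //   for (int h = 0; h < cfg.num_attention_heads; ++h) {
+  //     params->push(prefix + ".q_proj." + std::to_string(h) + ".weight",
+  //                  w.slice({{h * cfg.head_dim, (h + 1) * cfg.head_dim}, kAll, kAll, kAll}));
+  //   }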
+ + { + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model (seq=32)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("llama_qnn_aot_sha_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. 
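+  // Single-token decode graph for the SHA model, mirroring the N = 32 trace above.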
+ { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model (seq=1)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("llama_qnn_aot_sha_1.mir", [&]() { mllm::print(ir["model"]); }); + } + + qnn_aot_env.saveContext("context.0", "llama-lpbq-sha.bin"); + + mllm::print("SHA compilation completed successfully!"); + mllm::print("Output files:"); + mllm::print(" - llama_qnn_aot_sha_32.mir (IR dump for seq=32)"); + mllm::print(" - llama_qnn_aot_sha_1.mir (IR dump for seq=1)"); + mllm::print(" - llama-lpbq-sha.bin (QNN context)"); +}); diff --git a/examples/llama_qnn_aot/config_3B.json b/examples/llama_qnn_aot/config_3B.json new file mode 100644 index 00000000..ef7a3e94 --- /dev/null +++ b/examples/llama_qnn_aot/config_3B.json @@ -0,0 +1,28 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128009, + "attention_bias": false, + "hidden_act": "silu", + "hidden_size": 3072, + "head_dim": 128, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 131072, + "model_type": "llama", + "num_attention_heads": 24, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 500000.0, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.45.0", + "use_cache": true, + "vocab_size": 128256, + "max_cache_length": 2048, + "linear_impl_type": "QNN_LPBQ_w4a16o16_G32" +} diff --git a/examples/llama_qnn_aot/configuration_llama3.hpp b/examples/llama_qnn_aot/configuration_llama3.hpp new file mode 100644 index 00000000..16375ff6 --- /dev/null +++ b/examples/llama_qnn_aot/configuration_llama3.hpp @@ -0,0 +1,97 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::llama3 { + +/** + * @brief Configuration for Llama 3.x models used in QNN AOT compilation. + * + * This configuration is designed to support Llama 3.2 3B Instruct and similar models. 
+ * It includes all necessary fields for QNN AOT compilation, including: + * - head_dim: Dimension of each attention head + * - max_cache_length: Maximum KV cache length + * - end_of_text_token_id: Token ID for end of text + */ +struct Llama3Config : protected ConfigFile { + Llama3Config() = default; + + explicit Llama3Config(const std::string& file_path) : ConfigFile(file_path) { + // Init all + vocab_size = data()["vocab_size"]; + hidden_size = data()["hidden_size"]; + intermediate_size = data()["intermediate_size"]; + num_hidden_layers = data()["num_hidden_layers"]; + num_attention_heads = data()["num_attention_heads"]; + num_key_value_heads = data()["num_key_value_heads"]; + hidden_act = data()["hidden_act"]; + max_position_embeddings = data()["max_position_embeddings"]; + rms_norm_eps = data()["rms_norm_eps"]; + rope_theta = data()["rope_theta"]; + attention_bias = data()["attention_bias"]; + + // Handle head_dim - compute from hidden_size/num_attention_heads if not provided + if (data().contains("head_dim")) { + head_dim = data()["head_dim"]; + } else { + head_dim = hidden_size / num_attention_heads; + } + + // Handle default values for optional parameters + if (num_key_value_heads == 0) { num_key_value_heads = num_attention_heads; } + + // Token IDs + bos_token_id = data()["bos_token_id"]; + eos_token_id = data()["eos_token_id"]; + + // End of text token - use eos_token_id as default + if (data().contains("end_of_text_token_id")) { + end_of_text_token_id = data()["end_of_text_token_id"]; + } else { + end_of_text_token_id = eos_token_id; + } + + tie_word_embeddings = data()["tie_word_embeddings"]; + max_cache_length = data()["max_cache_length"]; + + linear_impl_type = aops::str2LinearImplTypes(data()["linear_impl_type"]); + } + + // Model architecture parameters + int32_t vocab_size = 128256; // Llama 3.2 vocabulary size + int32_t hidden_size = 3072; // Llama 3.2 3B hidden size + int32_t head_dim = 128; // Head dimension + int32_t intermediate_size = 8192; // FFN intermediate size + int32_t num_hidden_layers = 28; // Number of transformer layers + int32_t num_attention_heads = 24; // Number of attention heads + int32_t num_key_value_heads = 8; // Number of KV heads (GQA) + std::string hidden_act = "silu"; // Activation function + int32_t max_position_embeddings = 131072; // Max sequence length + + // Normalization and RoPE + float rms_norm_eps = 1e-5; // RMSNorm epsilon + float rope_theta = 500000.0; // RoPE base frequency + + // Attention + bool attention_bias = false; // Whether to use bias in attention + + // Token IDs + int64_t bos_token_id = 128000; // Begin of sequence token + int64_t eos_token_id = 128009; // End of sequence token + int32_t end_of_text_token_id = 128009; // End of text token for generation + + // Word embedding + bool tie_word_embeddings = true; // Tie input/output embeddings + + // Cache configuration + int32_t max_cache_length = 2048; // Maximum KV cache length + + // Linear implementation type for quantization + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::llama3 diff --git a/examples/llama_qnn_aot/modeling_llama_qnn_aot.hpp b/examples/llama_qnn_aot/modeling_llama_qnn_aot.hpp new file mode 100644 index 00000000..a129cd3b --- /dev/null +++ b/examples/llama_qnn_aot/modeling_llama_qnn_aot.hpp @@ -0,0 +1,508 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
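+//
+// Multi-head attention (MHA) variant of the Llama QNN AOT model. Linear layers
+// are expressed as 1x1 Conv2D so they lower to the QNN LPBQ w4a16o16 (G32)
+// kernels, and the ptq::QDQ helpers attach the fake-quant scales / zero-points
+// stored in the parameter file to each intermediate tensor. The per-head (SHA)
+// variant of this model lives in modeling_llama_qnn_aot_sha.hpp.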
+ +#pragma once + +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/compile/ir/Trace.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "configuration_llama3.hpp" + +namespace mllm::models::llama3 { + +namespace ptq { + +Tensor QDQ_CONSTANT(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = qdq_name_in_pytorch + ".scale"; + std::string zp_name = qdq_name_in_pytorch + ".zero_point"; + switch (in.dtype()) { + case kFloat32: + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + return in; +} + +Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + if (m->getModuleName().empty()) { + scale_name = qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point"; + } else { + scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + } + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + // For Constant! + case kFloat32: { + MLLM_RT_ASSERT_EQ(in.rank(), 1); + MLLM_RT_ASSERT_EQ(in.size(-1), 1); + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_KV(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + // The inputs is int8 sym. which means zero_point should be changed. + switch (in.dtype()) { + case kUInt8PerTensorSym: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + MLLM_RT_ASSERT_EQ(zp.item(), 0); + + // Is 128! not 127! + auto new_zp = Tensor::constant(128, kInt32).setName(zp_name).setMemType(kParamsNormal); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", new_zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." 
+ qdq_name_in_pytorch + ".fake_quant.zero_point"; + + (void)in.__unsafeSetDType(kUInt16PerTensorAsy); + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +} // namespace ptq + +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + +using vi32 = std::vector; +#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 + +// Using Conv2D to replace Linear. +// Conv2D Filter Weight is [1, 1, In, Out] +// Conv2D Activation is [N, H=1, W=Seq, In] + +class LlamaMLP final : public nn::Module { + nn::Conv2D gate_proj_; + nn::Conv2D up_proj_; + nn::Conv2D down_proj_; + nn::SiLU silu_; + int hidden_size_; + int intermediate_size_; + + public: + LlamaMLP() = default; + LlamaMLP(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY); + hidden_size_ = cfg.hidden_size; + intermediate_size_ = cfg.intermediate_size; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = ptq::QDQ(this, x, "up_proj_input_qdq"); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true); + auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true); + + // SiLU + gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")), + "act_output_qdq"); + + auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq"); + o = o.view({1, 1, -1, intermediate_size_}, true); + o = down_proj_(o).view({1, -1, hidden_size_}, true); + + return {o}; + } +}; + +class LlamaAttention final : public nn::Module { + nn::Conv2D q_proj_; + nn::Conv2D k_proj_; + nn::Conv2D v_proj_; + nn::Conv2D o_proj_; + // Llama does NOT have RMSNorm after QK projection (unlike Qwen3) + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + float scale_; + + public: + LlamaAttention() = default; + + LlamaAttention(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + scale_ = (1.f / sqrtf((float)head_dim_)); + + // clang-format off + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, 
CONV2D_PROPERTY); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, CONV2D_PROPERTY); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, CONV2D_PROPERTY); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, CONV2D_PROPERTY); + // clang-format on + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + // [B, S, D] + hidden_states = ptq::QDQ(this, hidden_states, "q_proj_input_qdq"); + hidden_states = hidden_states.view({1, 1, -1, hidden_size_}, true); + + // [B, S, H * D] + auto query_states = q_proj_(hidden_states); + auto key_states = k_proj_(hidden_states); + auto value_states = v_proj_(hidden_states); + + query_states = ptq::QDQ(this, query_states, "q_proj_output_qdq"); + key_states = ptq::QDQ(this, key_states, "k_proj_output_qdq"); + + // [B, H, S, D] + query_states = query_states.view({1, -1, num_attention_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + key_states = key_states.view({1, -1, num_key_value_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + value_states = value_states.view({1, -1, num_key_value_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + + // Llama does NOT have RMSNorm here (unlike Qwen3) + // Directly apply RoPE + + // [B, H, S, D] + auto cos = llm_embedding_cos.unsqueeze(1, true); + auto sin = llm_embedding_sin.unsqueeze(1, true); + query_states = + ptq::QDQ(this, + ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(query_states, this, "q_rope_neg_half_qdq") * sin, "q_rope_mul_1_output_qdq"), + "q_rope_add_0_output_qdq"); + key_states = + ptq::QDQ(this, + ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(key_states, this, "k_rope_neg_half_qdq") * sin, "k_rope_mul_1_output_qdq"), + "k_rope_add_0_output_qdq"); + + // De-quantization and quantization again + key_states = key_states.to(kFloat32); + key_states = key_states.to(kUInt8PerTensorSym); + key_states = ptq::QDQ_KV(this, key_states, "k_cast_to_int8_qdq"); + + // [B, H, D, S] + key_states = key_states.transpose(2, 3); + + // Handle KV Cache + value_states = ptq::QDQ(this, value_states, "v_cast_to_int16_qdq"); + value_states = value_states.to(kFloat32); + value_states = value_states.to(kUInt8PerTensorSym); + value_states = ptq::QDQ_KV(this, value_states, "v_cast_to_int8_qdq"); + + auto kh = nn::functional::concat({past_key, key_states}, -1); // [B, H, D, S] + auto vh = nn::functional::concat({past_value, value_states}, 2); // [B, H, S, D] + + // Repeat + kh = kh.repeat(num_key_value_groups_, 1); + vh = vh.repeat(num_key_value_groups_, 1); + + // Attn + auto attn = ptq::QDQ(this, nn::functional::matmul(query_states, kh), "qk_matmul_output_qdq"); + auto scale = Tensor::constant(scale_, kFloat32); + scale = ptq::QDQ(this, scale, "scaling_qdq"); + attn = ptq::QDQ(this, attn.mulConstant(scale), "mul_0_output_qdq"); + + // Masked Softmax + auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq"); + auto minus_value = Tensor::constant(-20, kFloat32); + minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq"); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq"); + auto zero_constant = Tensor::constant(0.f, kFloat32); 
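+    // The zero constant is given its QDQ parameters below and compared against the
+    // causal mask: positions where the mask is 0 keep their logits, all others take
+    // (row min - 20), a finite stand-in for the usual additive -inf mask that stays
+    // inside the quantized softmax range.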
+ zero_constant = ptq::QDQ_CONSTANT(this, zero_constant, "constant_zero"); + attn = nn::functional::where(causal_mask.equalConstant(zero_constant), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq"); + attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq"); + auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq"); + y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); + y = o_proj_(y).view({1, -1, hidden_size_}, true); + + return {y, key_states, value_states}; + } + + int layer_idx_; +}; + +class LlamaDecoder final : public nn::Module { + public: + int layer_idx_; + LlamaAttention self_attn_; + LlamaMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + LlamaDecoder() = default; + + LlamaDecoder(const std::string& name, const Llama3Config& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + auto hidden_states = inputs[0]; + if (layer_idx_ != 0) { hidden_states = ptq::QDQ(this, hidden_states, "input_layernorm_input_qdq"); } + auto residual = hidden_states; + hidden_states = input_layer_norm_(hidden_states); + auto _ = self_attn_(hidden_states, llm_embedding_sin, llm_embedding_cos, causal_mask, past_key, past_value); + hidden_states = _[0]; + hidden_states = ptq::QDQ(this, residual + ptq::QDQ(this, hidden_states, "add_0_lhs_input_qdq"), "add_0_output_qdq"); + residual = hidden_states; + hidden_states = post_attention_layer_norm_(hidden_states); + hidden_states = mlp_(hidden_states)[0]; + hidden_states = residual + ptq::QDQ(this, hidden_states, "add_1_lhs_input_qdq"); + return {hidden_states, _[1], _[2]}; + } +}; + +class LlamaText final : public nn::Module { + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + nn::Param rope_sin_; + nn::Param rope_cos_; + int32_t num_hidden_layers_; + int32_t hidden_size_; + + public: + LlamaText() = default; + + LlamaText(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + num_hidden_layers_ = cfg.num_hidden_layers; + hidden_size_ = cfg.hidden_size; + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + rope_sin_ = reg("mllm_max_sin_embedding", "model.mllm_max_sin_embedding"); + rope_cos_ = reg("mllm_max_cos_embedding", "model.mllm_max_cos_embedding"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + const auto& position_ids = inputs[1]; + auto causal_mask = inputs[2]; + + // clang-format off + auto llm_embedding_sin = nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids); + auto llm_embedding_cos = nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq"), 
1, position_ids); + // clang-format on + + std::vector keys; + std::vector values; + for (auto [index, block] : enumerate(blocks)) { + auto pk = inputs[3 + index]; + auto pv = inputs[3 + index + num_hidden_layers_]; + auto _ = block(x, llm_embedding_sin, llm_embedding_cos, causal_mask, pk, pv); + x = _[0]; + keys.push_back(_[1]); + values.push_back(_[2]); + } + + x = norm_(ptq::QDQ(this, x, "norm_input_qdq")); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto ret = std::vector{x}; + for (const auto& item : keys) { ret.push_back(item); } + for (const auto& item : values) { ret.push_back(item); } + + return ret; + } +}; + +class LlamaForCausalLM : public ARGeneration, public nn::Module { + public: + explicit LlamaForCausalLM(const Llama3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY); + } + } + + IROutput trace(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Things we need to return + ir::IRContext::ptr_t llm_ir = nullptr; + + auto sequence = input.at("sequence"); + auto causal_mask = input.at("causal_mask"); + + std::vector kv_caches; + + // Append Key + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + if (input.count(past_key_name)) { + kv_caches.push_back(input.at(past_key_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Append Value + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_value_name = "past_value_" + std::to_string(i); + if (input.count(past_value_name)) { + kv_caches.push_back(input.at(past_value_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({1}, kInt32, kCPU).alloc(); + *position_ids.offsettedPtr({0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({seq_len}, kInt32, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[s] = s; } + } + + ir::lowlevel::traceStart(); + + // Build inputs for llm: sequence, llm_embedding_sin, llm_embedding_cos, causal_mask, then all KV caches + std::vector llm_inputs = {sequence, position_ids, causal_mask}; + llm_inputs.insert(llm_inputs.end(), kv_caches.begin(), kv_caches.end()); + + sequence = 
llm(llm_inputs)[0]; + sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq")); + sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq"); + ir::lowlevel::traceComment(" ╔═════╗ "); + ir::lowlevel::traceComment(" ║ o o ║ "); + ir::lowlevel::traceComment(" ║ ▽ ║ "); + ir::lowlevel::traceComment(" ╚═════╝ "); + ir::lowlevel::traceComment(" ║ ║ "); + ir::lowlevel::traceComment(" ╱╩╦╦╩╲ "); + llm_ir = ir::lowlevel::traceStop(); + + return {{"model", llm_ir}}; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + private: + const Llama3Config& cfg; + LlamaText llm; + nn::Conv2D lm_head_; + bool tie_word_embeddings_; +}; + +} // namespace mllm::models::llama3 diff --git a/examples/llama_qnn_aot/modeling_llama_qnn_aot_sha.hpp b/examples/llama_qnn_aot/modeling_llama_qnn_aot_sha.hpp new file mode 100644 index 00000000..a26ebef1 --- /dev/null +++ b/examples/llama_qnn_aot/modeling_llama_qnn_aot_sha.hpp @@ -0,0 +1,788 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// The optimization splits large Q/K/V projections into per-head projections, +// allowing QNN to optimize each head separately, reducing AOT compilation time +// and improving HTP performance. + +#pragma once + +#include "mllm/core/TensorStorage.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/compile/ir/Trace.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "configuration_llama3.hpp" + +namespace mllm::models::llama3::sha { + +namespace ptq { + +Tensor QDQ_CONSTANT(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = qdq_name_in_pytorch + ".scale"; + std::string zp_name = qdq_name_in_pytorch + ".zero_point"; + switch (in.dtype()) { + case kFloat32: + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + return in; +} + +Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + if (m->getModuleName().empty()) { + scale_name = qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point"; + } else { + scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + } + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + // For Constant! 
+ case kFloat32: { + MLLM_RT_ASSERT_EQ(in.rank(), 1); + MLLM_RT_ASSERT_EQ(in.size(-1), 1); + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_KV(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + // The inputs is int8 sym. which means zero_point should be changed. + switch (in.dtype()) { + case kUInt8PerTensorSym: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + MLLM_RT_ASSERT_EQ(zp.item(), 0); + + // Is 128! not 127! + auto new_zp = Tensor::constant(128, kInt32).setName(zp_name).setMemType(kParamsNormal); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", new_zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + (void)in.__unsafeSetDType(kUInt16PerTensorAsy); + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +} // namespace ptq + +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + +using vi32 = std::vector; +#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 + +// Using Conv2D to replace Linear. 
+// Conv2D Filter Weight is [1, 1, In, Out] +// Conv2D Activation is [N, H=1, W=Seq, In] + +class LlamaMLP final : public nn::Module { + nn::Conv2D gate_proj_; + nn::Conv2D up_proj_; + nn::Conv2D down_proj_; + nn::SiLU silu_; + int hidden_size_; + int intermediate_size_; + + public: + LlamaMLP() = default; + LlamaMLP(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY); + hidden_size_ = cfg.hidden_size; + intermediate_size_ = cfg.intermediate_size; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = ptq::QDQ(this, x, "up_proj_input_qdq"); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true); + auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true); + + // SiLU + gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")), + "act_output_qdq"); + + auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq"); + o = o.view({1, 1, -1, intermediate_size_}, true); + o = down_proj_(o).view({1, -1, hidden_size_}, true); + + return {o}; + } +}; + +// ============================================================================ +// Single Head Attention (SHA) Implementation +// ============================================================================ +// +// This class implements SHA where each attention head has its own separate +// Conv2D projection, instead of one large MHA projection that processes all +// heads at once. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. Enables better memory locality per head +// +// Note: Llama does NOT have RMSNorm after Q/K projection (unlike Qwen3) + +class LlamaAttentionSHA final : public nn::Module { + // Per-head Q projections: num_attention_heads Conv2D(hidden_size, head_dim) + std::vector q_projs_; + // Per-head K projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector k_projs_; + // Per-head V projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector v_projs_; + // Single O projection remains unchanged (concatenated heads -> hidden_size) + nn::Conv2D o_proj_; + + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + float scale_; + + public: + LlamaAttentionSHA() = default; + + LlamaAttentionSHA(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + scale_ = (1.f / sqrtf((float)head_dim_)); + + // Register per-head Q projections + for (int h = 0; h < num_attention_heads_; ++h) { + q_projs_.emplace_back(reg("q_proj." 
+ std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head K projections + for (int h = 0; h < num_key_value_heads_; ++h) { + k_projs_.emplace_back(reg("k_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head V projections + for (int h = 0; h < num_key_value_heads_; ++h) { + v_projs_.emplace_back(reg("v_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // O projection remains the same (combines all heads) + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, CONV2D_PROPERTY); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + const auto& past_key = inputs[4]; // [B, num_kv_heads, D, S] + const auto& past_value = inputs[5]; // [B, num_kv_heads, S, D] + + // [B, S, D] - shared QDQ for input to all Q/K/V projections + hidden_states = ptq::QDQ(this, hidden_states, "q_proj_input_qdq"); + hidden_states = hidden_states.view({1, 1, -1, hidden_size_}, true); + + // ======================================================================== + // Per-head Q/K/V Projections + // ======================================================================== + // This is the key SHA optimization: instead of one large projection for all + // heads, we have separate smaller projections per head. + + // Compute per-head Q projections: each outputs [1, 1, S, head_dim] + std::vector query_states_per_head; + for (int h = 0; h < num_attention_heads_; ++h) { + auto q_h = q_projs_[h](hidden_states); + q_h = q_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + query_states_per_head.push_back(q_h); + } + + // Compute per-head K projections + std::vector key_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto k_h = k_projs_[h](hidden_states); + k_h = k_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + key_states_per_head.push_back(k_h); + } + + // Compute per-head V projections + std::vector value_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto v_h = v_projs_[h](hidden_states); + v_h = v_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + value_states_per_head.push_back(v_h); + } + + // ======================================================================== + // Reshape and Transpose for RoPE + // ======================================================================== + // Llama does NOT have RMSNorm here (unlike Qwen3) + // Directly apply RoPE after reshaping to [B, H, S, D] format + // Each head tensor is [1, 1, S, head_dim], need to reshape to [1, 1, S, head_dim] for RoPE + // (The shape is already correct, but we need to ensure QDQ is applied) + + auto cos = llm_embedding_cos.unsqueeze(1, true); + auto sin = llm_embedding_sin.unsqueeze(1, true); + + // Apply QDQ and RoPE per Q head + // Each query_states_per_head[h] is [1, 1, S, head_dim] + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + query_states_per_head[h] = ptq::QDQ(this, query_states_per_head[h], "q_proj_output_qdq_h" + h_str); + // Reshape to [1, 1, S, head_dim] for RoPE (already correct shape) + query_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, query_states_per_head[h] * cos, "q_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(query_states_per_head[h], this, "q_rope_neg_half_qdq_h" + 
h_str) * sin, + "q_rope_mul_1_output_qdq_h" + h_str), + "q_rope_add_0_output_qdq_h" + h_str); + } + + // Apply QDQ and RoPE per K head + // Each key_states_per_head[h] is [1, 1, S, head_dim] + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + key_states_per_head[h] = ptq::QDQ(this, key_states_per_head[h], "k_proj_output_qdq_h" + h_str); + // Reshape to [1, 1, S, head_dim] for RoPE (already correct shape) + key_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, key_states_per_head[h] * cos, "k_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(key_states_per_head[h], this, "k_rope_neg_half_qdq_h" + h_str) * sin, + "k_rope_mul_1_output_qdq_h" + h_str), + "k_rope_add_0_output_qdq_h" + h_str); + } + + // ======================================================================== + // KV Cache Processing per head + // ======================================================================== + + std::vector new_key_per_head; + std::vector new_value_per_head; + std::vector key_cache_per_head; + std::vector value_cache_per_head; + + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + + // K: De-quantize and re-quantize to int8 + auto k_h = key_states_per_head[h].to(kFloat32); + k_h = k_h.to(kUInt8PerTensorSym); + k_h = ptq::QDQ_KV(this, k_h, "k_cast_to_int8_qdq_h" + h_str); + k_h = k_h.transpose(2, 3); // [B, 1, D, S] + + // V: Quantize to int16 then int8 + auto v_h = ptq::QDQ(this, value_states_per_head[h], "v_cast_to_int16_qdq_h" + h_str); + v_h = v_h.to(kFloat32); + v_h = v_h.to(kUInt8PerTensorSym); + v_h = ptq::QDQ_KV(this, v_h, "v_cast_to_int8_qdq_h" + h_str); + + new_key_per_head.push_back(k_h); + new_value_per_head.push_back(v_h); + + // Slice past cache for this head + auto past_k_h = past_key.slice({kAll, {h, h + 1}, kAll, kAll}, true); + auto past_v_h = past_value.slice({kAll, {h, h + 1}, kAll, kAll}, true); + + // Concat current with past + key_cache_per_head.push_back(nn::functional::concat({past_k_h, k_h}, -1)); + value_cache_per_head.push_back(nn::functional::concat({past_v_h, v_h}, 2)); + } + + // ======================================================================== + // Per-head Attention Computation + // ======================================================================== + // Each Q head computes attention with its corresponding KV head (GQA support) + // For GQA, multiple Q heads share the same KV head + + std::vector attn_outputs; + + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + int kv_head_idx = h / num_key_value_groups_; + + const auto& q_h = query_states_per_head[h]; + const auto& kh = key_cache_per_head[kv_head_idx]; + const auto& vh = value_cache_per_head[kv_head_idx]; + + // QK^T + auto attn = ptq::QDQ(this, nn::functional::matmul(q_h, kh), "qk_matmul_output_qdq_h" + h_str); + + // Scale + auto scale = Tensor::constant(scale_, kFloat32); + scale = ptq::QDQ(this, scale, "scaling_qdq_h" + h_str); + attn = ptq::QDQ(this, attn.mulConstant(scale), "mul_0_output_qdq_h" + h_str); + + // Masked Softmax + auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq_h" + h_str); + auto minus_value = Tensor::constant(-20, kFloat32); + minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq_h" + h_str); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq_h" + h_str); + auto zero_constant = Tensor::constant(0.f, kFloat32); + zero_constant = ptq::QDQ_CONSTANT(this, zero_constant, 
"constant_zero"); + attn = nn::functional::where(causal_mask.equalConstant(zero_constant), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq_h" + h_str); + attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq_h" + h_str); + + // Output: attn @ V + auto y_h = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq_h" + h_str); + attn_outputs.push_back(y_h); + } + + // ======================================================================== + // Concatenate and Output Projection + // ======================================================================== + + // Concat all head outputs: [B, num_heads, S, D] + auto y = nn::functional::concat(attn_outputs, 1); + + // Reshape and apply O projection + y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); + y = o_proj_(y).view({1, -1, hidden_size_}, true); + + // Concat new keys and values back to original format + auto new_key = nn::functional::concat(new_key_per_head, 1); + auto new_value = nn::functional::concat(new_value_per_head, 1); + + return {y, new_key, new_value}; + } + + int layer_idx_; +}; + +class LlamaDecoderSHA final : public nn::Module { + public: + int layer_idx_; + LlamaAttentionSHA self_attn_; + LlamaMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + LlamaDecoderSHA() = default; + + LlamaDecoderSHA(const std::string& name, const Llama3Config& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + auto hidden_states = inputs[0]; + if (layer_idx_ != 0) { hidden_states = ptq::QDQ(this, hidden_states, "input_layernorm_input_qdq"); } + auto residual = hidden_states; + hidden_states = input_layer_norm_(hidden_states); + auto _ = self_attn_(hidden_states, llm_embedding_sin, llm_embedding_cos, causal_mask, past_key, past_value); + hidden_states = _[0]; + hidden_states = ptq::QDQ(this, residual + ptq::QDQ(this, hidden_states, "add_0_lhs_input_qdq"), "add_0_output_qdq"); + residual = hidden_states; + hidden_states = post_attention_layer_norm_(hidden_states); + hidden_states = mlp_(hidden_states)[0]; + hidden_states = residual + ptq::QDQ(this, hidden_states, "add_1_lhs_input_qdq"); + return {hidden_states, _[1], _[2]}; + } +}; + +class LlamaTextSHA final : public nn::Module { + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + nn::Param rope_sin_; + nn::Param rope_cos_; + int32_t num_hidden_layers_; + int32_t hidden_size_; + + public: + LlamaTextSHA() = default; + + LlamaTextSHA(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + num_hidden_layers_ = cfg.num_hidden_layers; + hidden_size_ = cfg.hidden_size; + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + rope_sin_ = reg("mllm_max_sin_embedding", "model.mllm_max_sin_embedding"); + rope_cos_ = 
reg("mllm_max_cos_embedding", "model.mllm_max_cos_embedding"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + const auto& position_ids = inputs[1]; + auto causal_mask = inputs[2]; + + // clang-format off + auto llm_embedding_sin = nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids); + auto llm_embedding_cos = nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq"), 1, position_ids); + // clang-format on + + std::vector keys; + std::vector values; + for (auto [index, block] : enumerate(blocks)) { + auto pk = inputs[3 + index]; + auto pv = inputs[3 + index + num_hidden_layers_]; + auto _ = block(x, llm_embedding_sin, llm_embedding_cos, causal_mask, pk, pv); + x = _[0]; + keys.push_back(_[1]); + values.push_back(_[2]); + } + + x = norm_(ptq::QDQ(this, x, "norm_input_qdq")); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto ret = std::vector{x}; + for (const auto& item : keys) { ret.push_back(item); } + for (const auto& item : values) { ret.push_back(item); } + + return ret; + } +}; + +class LlamaForCausalLM_SHA : public ARGeneration, public nn::Module { + public: + explicit LlamaForCausalLM_SHA(const Llama3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY); + } + } + + IROutput trace(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Things we need to return + ir::IRContext::ptr_t llm_ir = nullptr; + + auto sequence = input.at("sequence"); + auto causal_mask = input.at("causal_mask"); + + std::vector kv_caches; + + // Append Key + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + if (input.count(past_key_name)) { + kv_caches.push_back(input.at(past_key_name)); + } else { + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Append Value + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_value_name = "past_value_" + std::to_string(i); + if (input.count(past_value_name)) { + kv_caches.push_back(input.at(past_value_name)); + } else { + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({1}, kInt32, kCPU).alloc(); + *position_ids.offsettedPtr({0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({seq_len}, kInt32, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[s] = s; } + } + + ir::lowlevel::traceStart(); + + // Build inputs for llm: sequence, llm_embedding_sin, 
llm_embedding_cos, causal_mask, then all KV caches + std::vector llm_inputs = {sequence, position_ids, causal_mask}; + llm_inputs.insert(llm_inputs.end(), kv_caches.begin(), kv_caches.end()); + + sequence = llm(llm_inputs)[0]; + sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq")); + sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq"); + ir::lowlevel::traceComment(" ╔═════╗ "); + ir::lowlevel::traceComment(" ║ o o ║ "); + ir::lowlevel::traceComment(" ║ ▽ ║ "); + ir::lowlevel::traceComment(" ╚═════╝ "); + ir::lowlevel::traceComment(" ║ ║ "); + ir::lowlevel::traceComment(" ╱╩╦╦╩╲ "); + llm_ir = ir::lowlevel::traceStop(); + + return {{"model", llm_ir}}; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + private: + const Llama3Config& cfg; + LlamaTextSHA llm; + nn::Conv2D lm_head_; + bool tie_word_embeddings_; +}; + +// ============================================================================ +// Weight Slicing Utilities for SHA +// ============================================================================ +// +// These functions are used during the compile phase to slice the original +// MHA weights into per-head SHA weights. +// +// Note: Llama does NOT have q_norm and k_norm (RMSNorm), so we don't need +// to slice those parameters. + +/** + * @brief Prepares the parameter file by slicing MHA weights into SHA weights. + * + * This function takes the original parameter file with MHA weights and creates + * new per-head weights for the SHA model. + * + * Original weight layout for Conv2D: [out_channels, in_channels, 1, 1] + * - q_proj.weight: [num_heads * head_dim, hidden_size, 1, 1] + * - k_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * - v_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * + * For LPBQ quantization, also need to slice: + * - scale1: flattened scale for block quantization + * - scale2: flattened scale for block quantization + * + * SHA weight layout: + * - q_proj.{h}.weight: [head_dim, hidden_size, 1, 1] for h in [0, num_heads) + * - q_proj.{h}.scale1: sliced scale for head h + * - q_proj.{h}.scale2: sliced scale for head h + * - Similar for k_proj and v_proj + */ +inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Llama3Config& cfg) { + int num_heads = cfg.num_attention_heads; + int num_kv_heads = cfg.num_key_value_heads; + int head_dim = cfg.head_dim; + int num_layers = cfg.num_hidden_layers; + + // Helper lambda to slice and push Conv2D params (weight, scale1, scale2) + // For LPBQ, scale1 and scale2 are flattened along the output channel dimension + // Scale size per head = total_scale_size / num_heads_for_this_proj + auto sliceAndPushConv2DParams = [&](const std::string& orig_name_prefix, const std::string& new_name_prefix, + int total_out_channels, int out_channels_per_head, int num_splits) { + // Process weight: HWIO format [H=1, W=1, In_channels, Out_channels] + // For q_proj: [1, 1, hidden_size, num_heads * head_dim] + // Slice on the last dimension (Out_channels) + std::string orig_weight_name = orig_name_prefix + ".weight"; + if (params->has(orig_weight_name)) { + auto orig_weight = params->pull(orig_weight_name); + + for (int h = 0; h < num_splits; ++h) { + std::string new_weight_name = new_name_prefix + "." 
+ std::to_string(h) + ".weight"; + int start_idx = h * out_channels_per_head; + int end_idx = (h + 1) * out_channels_per_head; + // HWIO format: slice on dim 3 (Out_channels) + auto sliced = orig_weight.slice({kAll, kAll, kAll, {start_idx, end_idx}}, false); + params->push(new_weight_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_weight_name)); + } + } + + // Process scale1: flattened, size = total_out_channels / block_size (or similar) + // Slice index: (total_scale_size / num_splits) * h + std::string orig_scale1_name = orig_name_prefix + ".scale1"; + if (params->has(orig_scale1_name)) { + auto orig_scale1 = params->pull(orig_scale1_name); + int total_scale_size = orig_scale1.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale1_name = new_name_prefix + "." + std::to_string(h) + ".scale1"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale1.slice({{start_idx, end_idx}}, false); + params->push(new_scale1_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale1_name)); + } + } + + // Process scale2: flattened, same logic as scale1 + std::string orig_scale2_name = orig_name_prefix + ".scale2"; + if (params->has(orig_scale2_name)) { + auto orig_scale2 = params->pull(orig_scale2_name); + int total_scale_size = orig_scale2.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale2_name = new_name_prefix + "." + std::to_string(h) + ".scale2"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale2.slice({{start_idx, end_idx}}, false); + params->push(new_scale2_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale2_name)); + } + } + }; + + for (int layer = 0; layer < num_layers; ++layer) { + std::string layer_prefix = "model.layers." + std::to_string(layer) + ".self_attn."; + + // Process Q projection: split into num_heads parts + sliceAndPushConv2DParams(layer_prefix + "q_proj", layer_prefix + "q_proj", num_heads * head_dim, head_dim, num_heads); + + // Process K projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "k_proj", layer_prefix + "k_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // Process V projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "v_proj", layer_prefix + "v_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // ======================================================================== + // Duplicate QDQ parameters for each head + // ======================================================================== + // The original MHA uses shared QDQ params for all heads. For SHA, we + // duplicate these to per-head versions using "_h{N}" suffix naming. + // This allows each head to have its own quantization parameters. 
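+ //
+ // For example (illustrative names, following the copyQDQParams naming below):
+ // for layer 0 the shared entry
+ //   model.layers.0.self_attn.q_proj_output_qdq.fake_quant.scale
+ // is copied verbatim into the per-head entries
+ //   model.layers.0.self_attn.q_proj_output_qdq_h0.fake_quant.scale,
+ //   model.layers.0.self_attn.q_proj_output_qdq_h1.fake_quant.scale, ...
+ // so every head starts from the same calibration values and can be tuned
+ // independently later.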
+ + auto copyQDQParams = [&](const std::string& base_name, const std::string& new_base_name, int count) { + std::string scale_name = layer_prefix + base_name + ".fake_quant.scale"; + std::string zp_name = layer_prefix + base_name + ".fake_quant.zero_point"; + + if (params->has(scale_name)) { + auto scale = params->pull(scale_name); + auto zp = params->pull(zp_name); + + for (int h = 0; h < count; ++h) { + std::string new_scale_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.scale"; + std::string new_zp_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.zero_point"; + // QDQ scale/zp are typically scalar or small tensors, clone to ensure contiguous + params->push(new_scale_name, scale.contiguous().setMemType(kParamsNormal).setName(new_scale_name)); + params->push(new_zp_name, zp.contiguous().setMemType(kParamsNormal).setName(new_zp_name)); + } + } + }; + + // Copy QDQ params for Q-related nodes (per Q head) + copyQDQParams("q_proj_output_qdq", "q_proj_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_0_output_qdq", "q_rope_mul_0_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_1_output_qdq", "q_rope_mul_1_output_qdq_h", num_heads); + copyQDQParams("q_rope_neg_half_qdq", "q_rope_neg_half_qdq_h", num_heads); + copyQDQParams("q_rope_add_0_output_qdq", "q_rope_add_0_output_qdq_h", num_heads); + + // Copy QDQ params for K-related nodes (per KV head) + copyQDQParams("k_proj_output_qdq", "k_proj_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_0_output_qdq", "k_rope_mul_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_1_output_qdq", "k_rope_mul_1_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_neg_half_qdq", "k_rope_neg_half_qdq_h", num_kv_heads); + copyQDQParams("k_rope_add_0_output_qdq", "k_rope_add_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_cast_to_int8_qdq", "k_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for V-related nodes (per KV head) + copyQDQParams("v_cast_to_int16_qdq", "v_cast_to_int16_qdq_h", num_kv_heads); + copyQDQParams("v_cast_to_int8_qdq", "v_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for attention computation (per Q head) + copyQDQParams("qk_matmul_output_qdq", "qk_matmul_output_qdq_h", num_heads); + copyQDQParams("scaling_qdq", "scaling_qdq_h", num_heads); + copyQDQParams("mul_0_output_qdq", "mul_0_output_qdq_h", num_heads); + copyQDQParams("reduce_min_output_qdq", "reduce_min_output_qdq_h", num_heads); + copyQDQParams("neg_20_qdq", "neg_20_qdq_h", num_heads); + copyQDQParams("minus_0_output_qdq", "minus_0_output_qdq_h", num_heads); + copyQDQParams("where_attn_qdq", "where_attn_qdq_h", num_heads); + copyQDQParams("softmax_output_qdq", "softmax_output_qdq_h", num_heads); + copyQDQParams("attn_value_matmul_output_qdq", "attn_value_matmul_output_qdq_h", num_heads); + } +} + +} // namespace mllm::models::llama3::sha diff --git a/examples/llama_qnn_aot/qnn_aot_cfg_3B.json b/examples/llama_qnn_aot/qnn_aot_cfg_3B.json new file mode 100644 index 00000000..97240f5a --- /dev/null +++ b/examples/llama_qnn_aot/qnn_aot_cfg_3B.json @@ -0,0 +1,51 @@ +{ + "target_machine": { + "htp_arch": "V75", + "htp_chipset": "SM8650", + "htp_try_best_performance": "HtpBurst", + "htp_security_pd_session": "HtpSignedPd", + "htp_vtcm_capability_in_mb": 8 + }, + "graph_on_qnn": [ + "model" + ], + "op_on_qnn": [ + "lm_head" + ], + "split_graph": 1, + "quant_recipe": { + "llm_recipe": true, + "layers": 28, + "builtin_llm_pass": { + "model": "llama", + "lm_head": { + "fallback": { + "method": 
"LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "linear": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "kv_cache": { + "key": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + }, + "value": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + } + } + } + } +} diff --git a/examples/qwen2_qnn_aot/CMakeLists.txt b/examples/qwen2_qnn_aot/CMakeLists.txt new file mode 100644 index 00000000..4db6131c --- /dev/null +++ b/examples/qwen2_qnn_aot/CMakeLists.txt @@ -0,0 +1,14 @@ +# AOT targets run on x86 +if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) + add_executable(mllm-qwen2-aot-c compile.cpp) + target_link_libraries(mllm-qwen2-aot-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) + target_include_directories(mllm-qwen2-aot-c PRIVATE ${MLLM_INCLUDE_DIR}) + + add_executable(mllm-qwen2-aot-c-sha compile_sha.cpp) + target_link_libraries(mllm-qwen2-aot-c-sha PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) + target_include_directories(mllm-qwen2-aot-c-sha PRIVATE ${MLLM_INCLUDE_DIR}) +endif() + +add_executable(mllm-qwen2-aot-runner aot_run.cpp) +target_link_libraries(mllm-qwen2-aot-runner PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) +target_include_directories(mllm-qwen2-aot-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen2_qnn_aot/aot_run.cpp b/examples/qwen2_qnn_aot/aot_run.cpp new file mode 100644 index 00000000..14d2dadf --- /dev/null +++ b/examples/qwen2_qnn_aot/aot_run.cpp @@ -0,0 +1,61 @@ +#include +#include +#include +#include +#include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" +#include "mllm/models/qwen3/tokenization_qwen3.hpp" + +using mllm::Argparse; +using namespace mllm::qnn::aot; // NOLINT + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model").help("Model path").def("qwen2_qnn.mllm"); + auto& tokenizer_path = Argparse::add("-t|--tokenizer").help("Tokenizer path").def("tokenizer.json"); + auto& config_path = Argparse::add("-c|--config").help("Config path").required(true); + auto& ar_len = Argparse::add("--ar_len").help("Autoregressive length (chunk size)").def(128); + auto& seq_len = Argparse::add("--seq_len").help("Input sequence length").def(800); + auto& gen_len = Argparse::add("--gen_len").help("Generate token length").def(32); + + Argparse::parse(argc, argv); + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + mllm::initQnnBackend(model_path.get()); + + auto qwen2_cfg = mllm::models::qwen3::Qwen3Config(config_path.get()); + + RunnerConfig config; + config.num_layers = qwen2_cfg.num_hidden_layers; + config.num_heads = qwen2_cfg.num_attention_heads; + config.head_dim = qwen2_cfg.head_dim; + config.vocab_size = qwen2_cfg.vocab_size; + config.context_len = 1024; + config.ar_len = ar_len.get(); + + auto tokenizer = mllm::models::qwen3::Qwen3Tokenizer(tokenizer_path.get()); + + auto input_tensor = tokenizer.convertMessage({.prompt = "hello"}); + + input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); + + // DBG: + mllm::print(input_tensor["sequence"].shape()); + mllm::print(input_tensor["sequence"]); + + Runner runner(config, &tokenizer); + if (!runner.load()) { + std::cerr << "Failed to load model\n"; + return 1; + } + + runner.generate( + input_tensor["sequence"], gen_len.get(), [](const std::string& token) { std::cout << token << std::flush; }, true); + 
std::cout << "\n"; + + return 0; +}); diff --git a/examples/qwen2_qnn_aot/compile.cpp b/examples/qwen2_qnn_aot/compile.cpp new file mode 100644 index 00000000..28850196 --- /dev/null +++ b/examples/qwen2_qnn_aot/compile.cpp @@ -0,0 +1,159 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +#include "modeling_qwen2_qnn_aot.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); + auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); + auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + + Argparse::parse(argc, argv); + + int N = 32; + int CL = 1024; + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + if (!qnn_aot_cfg_files.isSet()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided"); + Argparse::printHelp(); + return -1; + } + + auto model_cfg = mllm::models::qwen3::Qwen3Config(model_cfg_path.get()); + auto model = mllm::models::qwen2::Qwen2ForCausalLM(model_cfg); + auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); + // Add params for causal mask + { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + } + model.load(params); + + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. + + { + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + auto ir = model.trace(trace_inputs, {}); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen2_qnn_aot_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. + { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + auto ir = model.trace(trace_inputs, {}); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen2_qnn_aot_1.mir", [&]() { mllm::print(ir["model"]); }); + } + + qnn_aot_env.saveContext("context.0", "qwen2-lpbq.bin"); +}); diff --git a/examples/qwen2_qnn_aot/compile_sha.cpp b/examples/qwen2_qnn_aot/compile_sha.cpp new file mode 100644 index 00000000..50aa9b5e --- /dev/null +++ b/examples/qwen2_qnn_aot/compile_sha.cpp @@ -0,0 +1,196 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. 
Enables better memory locality per head +// +// Usage: +// ./compile_sha -m /path/to/model.mllm -c /path/to/config.json -aot_cfg /path/to/qnn_aot_cfg.json + +#include +#include +#include +#include +#include +#include + +#include "modeling_qwen2_qnn_aot_sha.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); + auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); + auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + + Argparse::parse(argc, argv); + + int N = 32; + int CL = 1024; + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + if (!qnn_aot_cfg_files.isSet()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided"); + Argparse::printHelp(); + return -1; + } + + auto model_cfg = mllm::models::qwen3::Qwen3Config(model_cfg_path.get()); + + // Load original parameters + auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); + + // ============================================================================ + // Key Step: Prepare SHA parameters by slicing MHA weights + // ============================================================================ + // This is the critical step that transforms MHA weights into SHA weights. + // For each Q/K/V projection, we slice the weight matrix into per-head pieces. + // + // Original: q_proj.weight [num_heads * head_dim, hidden_size, 1, 1] + // SHA: q_proj.{h}.weight [head_dim, hidden_size, 1, 1] for each head h + // + mllm::print("Preparing SHA parameters (slicing MHA weights)..."); + mllm::models::qwen2::sha::prepareParametersForSHA(params, model_cfg); + mllm::print("SHA parameters prepared."); + + // Create SHA model + auto model = mllm::models::qwen2::sha::Qwen2ForCausalLM_SHA(model_cfg); + + // Add params for causal mask + { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + } + model.load(params); + + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. 
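+ // The model below is traced twice: first with N = 32 (chunked prefill), then
+ // with N = 1 (single-token decode, see the second block further down). Both
+ // graphs share the CL = 1024 context window, so each KV-cache input holds the
+ // remaining CL - N past positions.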
+ + { + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model (seq=32)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen2_qnn_aot_sha_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. 
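+ // Same setup as the N = 32 trace above, but with a single-token sequence so
+ // the decode graph is captured. Both traced graphs are lowered through the
+ // same qnn_aot_env and land in the single QNN context saved at the end of
+ // this program.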
+ { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model (seq=1)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen2_qnn_aot_sha_1.mir", [&]() { mllm::print(ir["model"]); }); + } + + qnn_aot_env.saveContext("context.0", "qwen2-lpbq-sha.bin"); + + mllm::print("SHA compilation completed successfully!"); + mllm::print("Output files:"); + mllm::print(" - qwen2_qnn_aot_sha_32.mir (IR dump for seq=32)"); + mllm::print(" - qwen2_qnn_aot_sha_1.mir (IR dump for seq=1)"); + mllm::print(" - qwen2-lpbq-sha.bin (QNN context)"); +}); diff --git a/examples/qwen2_qnn_aot/config_1.5B.json b/examples/qwen2_qnn_aot/config_1.5B.json new file mode 100644 index 00000000..e04d581b --- /dev/null +++ b/examples/qwen2_qnn_aot/config_1.5B.json @@ -0,0 +1,32 @@ +{ + "architectures": [ + "Qwen2ForCausalLM" + ], + "bos_token_id": 151643, + "eos_token_id": 151645, + "attention_bias": true, + "attention_dropout": 0.0, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 1536, + "initializer_range": 0.02, + "intermediate_size": 4096, + "max_position_embeddings": 32768, + "max_window_layers": 0, + "model_type": "qwen2", + "num_attention_heads": 12, + "num_hidden_layers": 28, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "max_cache_length": 2048, + "linear_impl_type": "QNN_LPBQ_w4a16o16_G32" +} diff --git a/examples/qwen2_qnn_aot/config_3B.json b/examples/qwen2_qnn_aot/config_3B.json new file mode 100644 index 00000000..e532f0d2 --- /dev/null +++ b/examples/qwen2_qnn_aot/config_3B.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "Qwen2ForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "attention_bias": true, + "hidden_act": "silu", + "hidden_size": 2048, + "head_dim": 128, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 32768, + "max_window_layers": 70, + "model_type": "qwen2", + "num_attention_heads": 16, + "num_hidden_layers": 36, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.43.1", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "max_cache_length": 2048, + "linear_impl_type": "QNN_LPBQ_w4a16o16_G32" +} diff --git a/examples/qwen2_qnn_aot/config_7B.json b/examples/qwen2_qnn_aot/config_7B.json new file mode 100644 index 00000000..8673b310 --- /dev/null +++ b/examples/qwen2_qnn_aot/config_7B.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "Qwen2ForCausalLM" + ], + "bos_token_id": 151643, + "eos_token_id": 151645, + "attention_bias": true, + "hidden_size": 3584, + "head_dim": 128, + "intermediate_size": 9728, + "num_attention_heads": 28, + "num_key_value_heads": 4, + "num_hidden_layers": 32, + "max_position_embeddings": 32768, + "rms_norm_eps": 1e-06, + "vocab_size": 151936, + "max_cache_length": 2048, + "rope_theta": 1000000.0, + "tie_word_embeddings": true, + "linear_impl_type": 
"QNN_LPBQ_w4a16o16_G32" +} diff --git a/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot.hpp b/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot.hpp new file mode 100644 index 00000000..26d57e67 --- /dev/null +++ b/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot.hpp @@ -0,0 +1,508 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/compile/ir/Trace.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" + +namespace mllm::models::qwen2 { + +namespace ptq { + +Tensor QDQ_CONSTANT(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = qdq_name_in_pytorch + ".scale"; + std::string zp_name = qdq_name_in_pytorch + ".zero_point"; + switch (in.dtype()) { + case kFloat32: + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + return in; +} + +Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + if (m->getModuleName().empty()) { + scale_name = qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point"; + } else { + scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + } + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + // For Constant! + case kFloat32: { + MLLM_RT_ASSERT_EQ(in.rank(), 1); + MLLM_RT_ASSERT_EQ(in.size(-1), 1); + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_KV(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + // The inputs is int8 sym. which means zero_point should be changed. + switch (in.dtype()) { + case kUInt8PerTensorSym: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + MLLM_RT_ASSERT_EQ(zp.item(), 0); + + // Is 128! not 127! 
+ auto new_zp = Tensor::constant(128, kInt32).setName(zp_name).setMemType(kParamsNormal); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", new_zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + (void)in.__unsafeSetDType(kUInt16PerTensorAsy); + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +} // namespace ptq + +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + +using vi32 = std::vector; +#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 + +// Using Conv2D to replace Linear. +// Conv2D Filter Weight is [1, 1, In, Out] +// Conv2D Activation is [N, H=1, W=Seq, In] + +class Qwen2MLP final : public nn::Module { + nn::Conv2D gate_proj_; + nn::Conv2D up_proj_; + nn::Conv2D down_proj_; + nn::SiLU silu_; + int hidden_size_; + int intermediate_size_; + + public: + Qwen2MLP() = default; + Qwen2MLP(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY); + hidden_size_ = cfg.hidden_size; + intermediate_size_ = cfg.intermediate_size; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = ptq::QDQ(this, x, "up_proj_input_qdq"); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true); + auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true); + + // SiLU + gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")), + "act_output_qdq"); + + auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq"); + o = o.view({1, 1, -1, intermediate_size_}, true); + o = down_proj_(o).view({1, -1, hidden_size_}, true); + + return {o}; + } +}; + +class Qwen2Attention final : public nn::Module { + nn::Conv2D q_proj_; + nn::Conv2D k_proj_; + nn::Conv2D v_proj_; + nn::Conv2D o_proj_; + // Qwen2 does NOT have RMSNorm after QK projection (unlike Qwen3) + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int 
num_key_value_groups_; + float scale_; + + public: + Qwen2Attention() = default; + + Qwen2Attention(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + scale_ = (1.f / sqrtf((float)head_dim_)); + + // clang-format off + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, CONV2D_PROPERTY); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, CONV2D_PROPERTY); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, CONV2D_PROPERTY); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, CONV2D_PROPERTY); + // clang-format on + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + // [B, S, D] + hidden_states = ptq::QDQ(this, hidden_states, "q_proj_input_qdq"); + hidden_states = hidden_states.view({1, 1, -1, hidden_size_}, true); + + // [B, S, H * D] + auto query_states = q_proj_(hidden_states); + auto key_states = k_proj_(hidden_states); + auto value_states = v_proj_(hidden_states); + + query_states = ptq::QDQ(this, query_states, "q_proj_output_qdq"); + key_states = ptq::QDQ(this, key_states, "k_proj_output_qdq"); + + // [B, H, S, D] + query_states = query_states.view({1, -1, num_attention_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + key_states = key_states.view({1, -1, num_key_value_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + value_states = value_states.view({1, -1, num_key_value_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + + // Qwen2 does NOT have RMSNorm here (unlike Qwen3) + // Directly apply RoPE + + // [B, H, S, D] + auto cos = llm_embedding_cos.unsqueeze(1, true); + auto sin = llm_embedding_sin.unsqueeze(1, true); + query_states = + ptq::QDQ(this, + ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(query_states, this, "q_rope_neg_half_qdq") * sin, "q_rope_mul_1_output_qdq"), + "q_rope_add_0_output_qdq"); + key_states = + ptq::QDQ(this, + ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(key_states, this, "k_rope_neg_half_qdq") * sin, "k_rope_mul_1_output_qdq"), + "k_rope_add_0_output_qdq"); + + // De-quantization and quantization again + key_states = key_states.to(kFloat32); + key_states = key_states.to(kUInt8PerTensorSym); + key_states = ptq::QDQ_KV(this, key_states, "k_cast_to_int8_qdq"); + + // [B, H, D, S] + key_states = key_states.transpose(2, 3); + + // Handle KV Cache + value_states = ptq::QDQ(this, value_states, "v_cast_to_int16_qdq"); + value_states = value_states.to(kFloat32); + value_states = value_states.to(kUInt8PerTensorSym); + value_states = ptq::QDQ_KV(this, value_states, "v_cast_to_int8_qdq"); + + auto kh = nn::functional::concat({past_key, key_states}, -1); // [B, H, D, S] + auto vh = nn::functional::concat({past_value, value_states}, 2); // [B, H, S, D] + + // Repeat + kh = kh.repeat(num_key_value_groups_, 1); + vh = vh.repeat(num_key_value_groups_, 1); + + // Attn + auto attn = ptq::QDQ(this, nn::functional::matmul(query_states, kh), 
"qk_matmul_output_qdq"); + auto scale = Tensor::constant(scale_, kFloat32); + scale = ptq::QDQ(this, scale, "scaling_qdq"); + attn = ptq::QDQ(this, attn.mulConstant(scale), "mul_0_output_qdq"); + + // Masked Softmax + auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq"); + auto minus_value = Tensor::constant(-20, kFloat32); + minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq"); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq"); + auto zero_constant = Tensor::constant(0.f, kFloat32); + zero_constant = ptq::QDQ_CONSTANT(this, zero_constant, "constant_zero"); + attn = nn::functional::where(causal_mask.equalConstant(zero_constant), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq"); + attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq"); + auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq"); + y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); + y = o_proj_(y).view({1, -1, hidden_size_}, true); + + return {y, key_states, value_states}; + } + + int layer_idx_; +}; + +class Qwen2Decoder final : public nn::Module { + public: + int layer_idx_; + Qwen2Attention self_attn_; + Qwen2MLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2Decoder() = default; + + Qwen2Decoder(const std::string& name, const qwen3::Qwen3Config& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + auto hidden_states = inputs[0]; + if (layer_idx_ != 0) { hidden_states = ptq::QDQ(this, hidden_states, "input_layernorm_input_qdq"); } + auto residual = hidden_states; + hidden_states = input_layer_norm_(hidden_states); + auto _ = self_attn_(hidden_states, llm_embedding_sin, llm_embedding_cos, causal_mask, past_key, past_value); + hidden_states = _[0]; + hidden_states = ptq::QDQ(this, residual + ptq::QDQ(this, hidden_states, "add_0_lhs_input_qdq"), "add_0_output_qdq"); + residual = hidden_states; + hidden_states = post_attention_layer_norm_(hidden_states); + hidden_states = mlp_(hidden_states)[0]; + hidden_states = residual + ptq::QDQ(this, hidden_states, "add_1_lhs_input_qdq"); + return {hidden_states, _[1], _[2]}; + } +}; + +class Qwen2Text final : public nn::Module { + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + nn::Param rope_sin_; + nn::Param rope_cos_; + int32_t num_hidden_layers_; + int32_t hidden_size_; + + public: + Qwen2Text() = default; + + Qwen2Text(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + num_hidden_layers_ = cfg.num_hidden_layers; + hidden_size_ = cfg.hidden_size; + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + rope_sin_ = reg("mllm_max_sin_embedding", "model.mllm_max_sin_embedding"); + rope_cos_ = reg("mllm_max_cos_embedding", 
"model.mllm_max_cos_embedding"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + const auto& position_ids = inputs[1]; + auto causal_mask = inputs[2]; + + // clang-format off + auto llm_embedding_sin = nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids); + auto llm_embedding_cos = nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq"), 1, position_ids); + // clang-format on + + std::vector keys; + std::vector values; + for (auto [index, block] : enumerate(blocks)) { + auto pk = inputs[3 + index]; + auto pv = inputs[3 + index + num_hidden_layers_]; + auto _ = block(x, llm_embedding_sin, llm_embedding_cos, causal_mask, pk, pv); + x = _[0]; + keys.push_back(_[1]); + values.push_back(_[2]); + } + + x = norm_(ptq::QDQ(this, x, "norm_input_qdq")); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto ret = std::vector{x}; + for (const auto& item : keys) { ret.push_back(item); } + for (const auto& item : values) { ret.push_back(item); } + + return ret; + } +}; + +class Qwen2ForCausalLM : public ARGeneration, public nn::Module { + public: + explicit Qwen2ForCausalLM(const qwen3::Qwen3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY); + } + } + + IROutput trace(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Things we need to return + ir::IRContext::ptr_t llm_ir = nullptr; + + auto sequence = input.at("sequence"); + auto causal_mask = input.at("causal_mask"); + + std::vector kv_caches; + + // Append Key + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + if (input.count(past_key_name)) { + kv_caches.push_back(input.at(past_key_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Append Value + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_value_name = "past_value_" + std::to_string(i); + if (input.count(past_value_name)) { + kv_caches.push_back(input.at(past_value_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({1}, kInt32, kCPU).alloc(); + 
*position_ids.offsettedPtr({0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({seq_len}, kInt32, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[s] = s; } + } + + ir::lowlevel::traceStart(); + + // Build inputs for llm: sequence, llm_embedding_sin, llm_embedding_cos, causal_mask, then all KV caches + std::vector llm_inputs = {sequence, position_ids, causal_mask}; + llm_inputs.insert(llm_inputs.end(), kv_caches.begin(), kv_caches.end()); + + sequence = llm(llm_inputs)[0]; + sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq")); + sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq"); + ir::lowlevel::traceComment(" ╔═════╗ "); + ir::lowlevel::traceComment(" ║ o o ║ "); + ir::lowlevel::traceComment(" ║ ▽ ║ "); + ir::lowlevel::traceComment(" ╚═════╝ "); + ir::lowlevel::traceComment(" ║ ║ "); + ir::lowlevel::traceComment(" ╱╩╦╦╩╲ "); + llm_ir = ir::lowlevel::traceStop(); + + return {{"model", llm_ir}}; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + private: + const qwen3::Qwen3Config& cfg; + Qwen2Text llm; + nn::Conv2D lm_head_; + bool tie_word_embeddings_; +}; + +} // namespace mllm::models::qwen2 diff --git a/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot_sha.hpp b/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot_sha.hpp new file mode 100644 index 00000000..db69c601 --- /dev/null +++ b/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot_sha.hpp @@ -0,0 +1,788 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// The optimization splits large Q/K/V projections into per-head projections, +// allowing QNN to optimize each head separately, reducing AOT compilation time +// and improving HTP performance. + +#pragma once + +#include "mllm/core/TensorStorage.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/compile/ir/Trace.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" + +namespace mllm::models::qwen2::sha { + +namespace ptq { + +Tensor QDQ_CONSTANT(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = qdq_name_in_pytorch + ".scale"; + std::string zp_name = qdq_name_in_pytorch + ".zero_point"; + switch (in.dtype()) { + case kFloat32: + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + return in; +} + +Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + if (m->getModuleName().empty()) { + scale_name = qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point"; + } else { + scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = m->getModuleName() + "." 
+ qdq_name_in_pytorch + ".fake_quant.zero_point"; + } + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + // For Constant! + case kFloat32: { + MLLM_RT_ASSERT_EQ(in.rank(), 1); + MLLM_RT_ASSERT_EQ(in.size(-1), 1); + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_KV(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + // The inputs is int8 sym. which means zero_point should be changed. + switch (in.dtype()) { + case kUInt8PerTensorSym: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + MLLM_RT_ASSERT_EQ(zp.item(), 0); + + // Is 128! not 127! + auto new_zp = Tensor::constant(128, kInt32).setName(zp_name).setMemType(kParamsNormal); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", new_zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + (void)in.__unsafeSetDType(kUInt16PerTensorAsy); + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +} // namespace ptq + +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + +using vi32 = std::vector; +#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 + +// Using Conv2D to replace Linear. 
+// Conv2D Filter Weight is [1, 1, In, Out] +// Conv2D Activation is [N, H=1, W=Seq, In] + +class Qwen2MLP final : public nn::Module { + nn::Conv2D gate_proj_; + nn::Conv2D up_proj_; + nn::Conv2D down_proj_; + nn::SiLU silu_; + int hidden_size_; + int intermediate_size_; + + public: + Qwen2MLP() = default; + Qwen2MLP(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY); + hidden_size_ = cfg.hidden_size; + intermediate_size_ = cfg.intermediate_size; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = ptq::QDQ(this, x, "up_proj_input_qdq"); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true); + auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true); + + // SiLU + gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")), + "act_output_qdq"); + + auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq"); + o = o.view({1, 1, -1, intermediate_size_}, true); + o = down_proj_(o).view({1, -1, hidden_size_}, true); + + return {o}; + } +}; + +// ============================================================================ +// Single Head Attention (SHA) Implementation +// ============================================================================ +// +// This class implements SHA where each attention head has its own separate +// Conv2D projection, instead of one large MHA projection that processes all +// heads at once. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. Enables better memory locality per head +// +// Note: Qwen2 does NOT have RMSNorm after Q/K projection (unlike Qwen3) + +class Qwen2AttentionSHA final : public nn::Module { + // Per-head Q projections: num_attention_heads Conv2D(hidden_size, head_dim) + std::vector q_projs_; + // Per-head K projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector k_projs_; + // Per-head V projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector v_projs_; + // Single O projection remains unchanged (concatenated heads -> hidden_size) + nn::Conv2D o_proj_; + + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + float scale_; + + public: + Qwen2AttentionSHA() = default; + + Qwen2AttentionSHA(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + scale_ = (1.f / sqrtf((float)head_dim_)); + + // Register per-head Q projections + for (int h = 0; h < num_attention_heads_; ++h) { + q_projs_.emplace_back(reg("q_proj." 
+ std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head K projections + for (int h = 0; h < num_key_value_heads_; ++h) { + k_projs_.emplace_back(reg("k_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head V projections + for (int h = 0; h < num_key_value_heads_; ++h) { + v_projs_.emplace_back(reg("v_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // O projection remains the same (combines all heads) + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, CONV2D_PROPERTY); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + const auto& past_key = inputs[4]; // [B, num_kv_heads, D, S] + const auto& past_value = inputs[5]; // [B, num_kv_heads, S, D] + + // [B, S, D] - shared QDQ for input to all Q/K/V projections + hidden_states = ptq::QDQ(this, hidden_states, "q_proj_input_qdq"); + hidden_states = hidden_states.view({1, 1, -1, hidden_size_}, true); + + // ======================================================================== + // Per-head Q/K/V Projections + // ======================================================================== + // This is the key SHA optimization: instead of one large projection for all + // heads, we have separate smaller projections per head. + + // Compute per-head Q projections: each outputs [1, 1, S, head_dim] + std::vector query_states_per_head; + for (int h = 0; h < num_attention_heads_; ++h) { + auto q_h = q_projs_[h](hidden_states); + q_h = q_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + query_states_per_head.push_back(q_h); + } + + // Compute per-head K projections + std::vector key_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto k_h = k_projs_[h](hidden_states); + k_h = k_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + key_states_per_head.push_back(k_h); + } + + // Compute per-head V projections + std::vector value_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto v_h = v_projs_[h](hidden_states); + v_h = v_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + value_states_per_head.push_back(v_h); + } + + // ======================================================================== + // Reshape and Transpose for RoPE + // ======================================================================== + // Qwen2 does NOT have RMSNorm here (unlike Qwen3) + // Directly apply RoPE after reshaping to [B, H, S, D] format + // Each head tensor is [1, 1, S, head_dim], need to reshape to [1, 1, S, head_dim] for RoPE + // (The shape is already correct, but we need to ensure QDQ is applied) + + auto cos = llm_embedding_cos.unsqueeze(1, true); + auto sin = llm_embedding_sin.unsqueeze(1, true); + + // Apply QDQ and RoPE per Q head + // Each query_states_per_head[h] is [1, 1, S, head_dim] + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + query_states_per_head[h] = ptq::QDQ(this, query_states_per_head[h], "q_proj_output_qdq_h" + h_str); + // Reshape to [1, 1, S, head_dim] for RoPE (already correct shape) + query_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, query_states_per_head[h] * cos, "q_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(query_states_per_head[h], this, "q_rope_neg_half_qdq_h" + 
h_str) * sin, + "q_rope_mul_1_output_qdq_h" + h_str), + "q_rope_add_0_output_qdq_h" + h_str); + } + + // Apply QDQ and RoPE per K head + // Each key_states_per_head[h] is [1, 1, S, head_dim] + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + key_states_per_head[h] = ptq::QDQ(this, key_states_per_head[h], "k_proj_output_qdq_h" + h_str); + // Reshape to [1, 1, S, head_dim] for RoPE (already correct shape) + key_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, key_states_per_head[h] * cos, "k_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(key_states_per_head[h], this, "k_rope_neg_half_qdq_h" + h_str) * sin, + "k_rope_mul_1_output_qdq_h" + h_str), + "k_rope_add_0_output_qdq_h" + h_str); + } + + // ======================================================================== + // KV Cache Processing per head + // ======================================================================== + + std::vector new_key_per_head; + std::vector new_value_per_head; + std::vector key_cache_per_head; + std::vector value_cache_per_head; + + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + + // K: De-quantize and re-quantize to int8 + auto k_h = key_states_per_head[h].to(kFloat32); + k_h = k_h.to(kUInt8PerTensorSym); + k_h = ptq::QDQ_KV(this, k_h, "k_cast_to_int8_qdq_h" + h_str); + k_h = k_h.transpose(2, 3); // [B, 1, D, S] + + // V: Quantize to int16 then int8 + auto v_h = ptq::QDQ(this, value_states_per_head[h], "v_cast_to_int16_qdq_h" + h_str); + v_h = v_h.to(kFloat32); + v_h = v_h.to(kUInt8PerTensorSym); + v_h = ptq::QDQ_KV(this, v_h, "v_cast_to_int8_qdq_h" + h_str); + + new_key_per_head.push_back(k_h); + new_value_per_head.push_back(v_h); + + // Slice past cache for this head + auto past_k_h = past_key.slice({kAll, {h, h + 1}, kAll, kAll}, true); + auto past_v_h = past_value.slice({kAll, {h, h + 1}, kAll, kAll}, true); + + // Concat current with past + key_cache_per_head.push_back(nn::functional::concat({past_k_h, k_h}, -1)); + value_cache_per_head.push_back(nn::functional::concat({past_v_h, v_h}, 2)); + } + + // ======================================================================== + // Per-head Attention Computation + // ======================================================================== + // Each Q head computes attention with its corresponding KV head (GQA support) + // For GQA, multiple Q heads share the same KV head + + std::vector attn_outputs; + + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + int kv_head_idx = h / num_key_value_groups_; + + const auto& q_h = query_states_per_head[h]; + const auto& kh = key_cache_per_head[kv_head_idx]; + const auto& vh = value_cache_per_head[kv_head_idx]; + + // QK^T + auto attn = ptq::QDQ(this, nn::functional::matmul(q_h, kh), "qk_matmul_output_qdq_h" + h_str); + + // Scale + auto scale = Tensor::constant(scale_, kFloat32); + scale = ptq::QDQ(this, scale, "scaling_qdq_h" + h_str); + attn = ptq::QDQ(this, attn.mulConstant(scale), "mul_0_output_qdq_h" + h_str); + + // Masked Softmax + auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq_h" + h_str); + auto minus_value = Tensor::constant(-20, kFloat32); + minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq_h" + h_str); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq_h" + h_str); + auto zero_constant = Tensor::constant(0.f, kFloat32); + zero_constant = ptq::QDQ_CONSTANT(this, zero_constant, 
"constant_zero"); + attn = nn::functional::where(causal_mask.equalConstant(zero_constant), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq_h" + h_str); + attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq_h" + h_str); + + // Output: attn @ V + auto y_h = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq_h" + h_str); + attn_outputs.push_back(y_h); + } + + // ======================================================================== + // Concatenate and Output Projection + // ======================================================================== + + // Concat all head outputs: [B, num_heads, S, D] + auto y = nn::functional::concat(attn_outputs, 1); + + // Reshape and apply O projection + y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); + y = o_proj_(y).view({1, -1, hidden_size_}, true); + + // Concat new keys and values back to original format + auto new_key = nn::functional::concat(new_key_per_head, 1); + auto new_value = nn::functional::concat(new_value_per_head, 1); + + return {y, new_key, new_value}; + } + + int layer_idx_; +}; + +class Qwen2DecoderSHA final : public nn::Module { + public: + int layer_idx_; + Qwen2AttentionSHA self_attn_; + Qwen2MLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2DecoderSHA() = default; + + Qwen2DecoderSHA(const std::string& name, const qwen3::Qwen3Config& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + auto hidden_states = inputs[0]; + if (layer_idx_ != 0) { hidden_states = ptq::QDQ(this, hidden_states, "input_layernorm_input_qdq"); } + auto residual = hidden_states; + hidden_states = input_layer_norm_(hidden_states); + auto _ = self_attn_(hidden_states, llm_embedding_sin, llm_embedding_cos, causal_mask, past_key, past_value); + hidden_states = _[0]; + hidden_states = ptq::QDQ(this, residual + ptq::QDQ(this, hidden_states, "add_0_lhs_input_qdq"), "add_0_output_qdq"); + residual = hidden_states; + hidden_states = post_attention_layer_norm_(hidden_states); + hidden_states = mlp_(hidden_states)[0]; + hidden_states = residual + ptq::QDQ(this, hidden_states, "add_1_lhs_input_qdq"); + return {hidden_states, _[1], _[2]}; + } +}; + +class Qwen2TextSHA final : public nn::Module { + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + nn::Param rope_sin_; + nn::Param rope_cos_; + int32_t num_hidden_layers_; + int32_t hidden_size_; + + public: + Qwen2TextSHA() = default; + + Qwen2TextSHA(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + num_hidden_layers_ = cfg.num_hidden_layers; + hidden_size_ = cfg.hidden_size; + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + rope_sin_ = reg("mllm_max_sin_embedding", "model.mllm_max_sin_embedding"); + rope_cos_ = 
reg("mllm_max_cos_embedding", "model.mllm_max_cos_embedding"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + const auto& position_ids = inputs[1]; + auto causal_mask = inputs[2]; + + // clang-format off + auto llm_embedding_sin = nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids); + auto llm_embedding_cos = nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq"), 1, position_ids); + // clang-format on + + std::vector keys; + std::vector values; + for (auto [index, block] : enumerate(blocks)) { + auto pk = inputs[3 + index]; + auto pv = inputs[3 + index + num_hidden_layers_]; + auto _ = block(x, llm_embedding_sin, llm_embedding_cos, causal_mask, pk, pv); + x = _[0]; + keys.push_back(_[1]); + values.push_back(_[2]); + } + + x = norm_(ptq::QDQ(this, x, "norm_input_qdq")); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto ret = std::vector{x}; + for (const auto& item : keys) { ret.push_back(item); } + for (const auto& item : values) { ret.push_back(item); } + + return ret; + } +}; + +class Qwen2ForCausalLM_SHA : public ARGeneration, public nn::Module { + public: + explicit Qwen2ForCausalLM_SHA(const qwen3::Qwen3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY); + } + } + + IROutput trace(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Things we need to return + ir::IRContext::ptr_t llm_ir = nullptr; + + auto sequence = input.at("sequence"); + auto causal_mask = input.at("causal_mask"); + + std::vector kv_caches; + + // Append Key + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + if (input.count(past_key_name)) { + kv_caches.push_back(input.at(past_key_name)); + } else { + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Append Value + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_value_name = "past_value_" + std::to_string(i); + if (input.count(past_value_name)) { + kv_caches.push_back(input.at(past_value_name)); + } else { + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({1}, kInt32, kCPU).alloc(); + *position_ids.offsettedPtr({0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({seq_len}, kInt32, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[s] = s; } + } + + ir::lowlevel::traceStart(); + + // Build inputs for llm: sequence, 
llm_embedding_sin, llm_embedding_cos, causal_mask, then all KV caches + std::vector llm_inputs = {sequence, position_ids, causal_mask}; + llm_inputs.insert(llm_inputs.end(), kv_caches.begin(), kv_caches.end()); + + sequence = llm(llm_inputs)[0]; + sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq")); + sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq"); + ir::lowlevel::traceComment(" ╔═════╗ "); + ir::lowlevel::traceComment(" ║ o o ║ "); + ir::lowlevel::traceComment(" ║ ▽ ║ "); + ir::lowlevel::traceComment(" ╚═════╝ "); + ir::lowlevel::traceComment(" ║ ║ "); + ir::lowlevel::traceComment(" ╱╩╦╦╩╲ "); + llm_ir = ir::lowlevel::traceStop(); + + return {{"model", llm_ir}}; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + private: + const qwen3::Qwen3Config& cfg; + Qwen2TextSHA llm; + nn::Conv2D lm_head_; + bool tie_word_embeddings_; +}; + +// ============================================================================ +// Weight Slicing Utilities for SHA +// ============================================================================ +// +// These functions are used during the compile phase to slice the original +// MHA weights into per-head SHA weights. +// +// Note: Qwen2 does NOT have q_norm and k_norm (RMSNorm), so we don't need +// to slice those parameters. + +/** + * @brief Prepares the parameter file by slicing MHA weights into SHA weights. + * + * This function takes the original parameter file with MHA weights and creates + * new per-head weights for the SHA model. + * + * Original weight layout for Conv2D: [out_channels, in_channels, 1, 1] + * - q_proj.weight: [num_heads * head_dim, hidden_size, 1, 1] + * - k_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * - v_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * + * For LPBQ quantization, also need to slice: + * - scale1: flattened scale for block quantization + * - scale2: flattened scale for block quantization + * + * SHA weight layout: + * - q_proj.{h}.weight: [head_dim, hidden_size, 1, 1] for h in [0, num_heads) + * - q_proj.{h}.scale1: sliced scale for head h + * - q_proj.{h}.scale2: sliced scale for head h + * - Similar for k_proj and v_proj + */ +inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const qwen3::Qwen3Config& cfg) { + int num_heads = cfg.num_attention_heads; + int num_kv_heads = cfg.num_key_value_heads; + int head_dim = cfg.head_dim; + int num_layers = cfg.num_hidden_layers; + + // Helper lambda to slice and push Conv2D params (weight, scale1, scale2) + // For LPBQ, scale1 and scale2 are flattened along the output channel dimension + // Scale size per head = total_scale_size / num_heads_for_this_proj + auto sliceAndPushConv2DParams = [&](const std::string& orig_name_prefix, const std::string& new_name_prefix, + int total_out_channels, int out_channels_per_head, int num_splits) { + // Process weight: HWIO format [H=1, W=1, In_channels, Out_channels] + // For q_proj: [1, 1, hidden_size, num_heads * head_dim] + // Slice on the last dimension (Out_channels) + std::string orig_weight_name = orig_name_prefix + ".weight"; + if (params->has(orig_weight_name)) { + auto orig_weight = params->pull(orig_weight_name); + + for (int h = 0; h < num_splits; ++h) { + std::string new_weight_name = new_name_prefix + "." 
+ std::to_string(h) + ".weight"; + int start_idx = h * out_channels_per_head; + int end_idx = (h + 1) * out_channels_per_head; + // HWIO format: slice on dim 3 (Out_channels) + auto sliced = orig_weight.slice({kAll, kAll, kAll, {start_idx, end_idx}}, false); + params->push(new_weight_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_weight_name)); + } + } + + // Process scale1: flattened, size = total_out_channels / block_size (or similar) + // Slice index: (total_scale_size / num_splits) * h + std::string orig_scale1_name = orig_name_prefix + ".scale1"; + if (params->has(orig_scale1_name)) { + auto orig_scale1 = params->pull(orig_scale1_name); + int total_scale_size = orig_scale1.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale1_name = new_name_prefix + "." + std::to_string(h) + ".scale1"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale1.slice({{start_idx, end_idx}}, false); + params->push(new_scale1_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale1_name)); + } + } + + // Process scale2: flattened, same logic as scale1 + std::string orig_scale2_name = orig_name_prefix + ".scale2"; + if (params->has(orig_scale2_name)) { + auto orig_scale2 = params->pull(orig_scale2_name); + int total_scale_size = orig_scale2.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale2_name = new_name_prefix + "." + std::to_string(h) + ".scale2"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale2.slice({{start_idx, end_idx}}, false); + params->push(new_scale2_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale2_name)); + } + } + }; + + for (int layer = 0; layer < num_layers; ++layer) { + std::string layer_prefix = "model.layers." + std::to_string(layer) + ".self_attn."; + + // Process Q projection: split into num_heads parts + sliceAndPushConv2DParams(layer_prefix + "q_proj", layer_prefix + "q_proj", num_heads * head_dim, head_dim, num_heads); + + // Process K projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "k_proj", layer_prefix + "k_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // Process V projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "v_proj", layer_prefix + "v_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // ======================================================================== + // Duplicate QDQ parameters for each head + // ======================================================================== + // The original MHA uses shared QDQ params for all heads. For SHA, we + // duplicate these to per-head versions using "_h{N}" suffix naming. + // This allows each head to have its own quantization parameters. 
+ + auto copyQDQParams = [&](const std::string& base_name, const std::string& new_base_name, int count) { + std::string scale_name = layer_prefix + base_name + ".fake_quant.scale"; + std::string zp_name = layer_prefix + base_name + ".fake_quant.zero_point"; + + if (params->has(scale_name)) { + auto scale = params->pull(scale_name); + auto zp = params->pull(zp_name); + + for (int h = 0; h < count; ++h) { + std::string new_scale_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.scale"; + std::string new_zp_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.zero_point"; + // QDQ scale/zp are typically scalar or small tensors, clone to ensure contiguous + params->push(new_scale_name, scale.contiguous().setMemType(kParamsNormal).setName(new_scale_name)); + params->push(new_zp_name, zp.contiguous().setMemType(kParamsNormal).setName(new_zp_name)); + } + } + }; + + // Copy QDQ params for Q-related nodes (per Q head) + copyQDQParams("q_proj_output_qdq", "q_proj_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_0_output_qdq", "q_rope_mul_0_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_1_output_qdq", "q_rope_mul_1_output_qdq_h", num_heads); + copyQDQParams("q_rope_neg_half_qdq", "q_rope_neg_half_qdq_h", num_heads); + copyQDQParams("q_rope_add_0_output_qdq", "q_rope_add_0_output_qdq_h", num_heads); + + // Copy QDQ params for K-related nodes (per KV head) + copyQDQParams("k_proj_output_qdq", "k_proj_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_0_output_qdq", "k_rope_mul_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_1_output_qdq", "k_rope_mul_1_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_neg_half_qdq", "k_rope_neg_half_qdq_h", num_kv_heads); + copyQDQParams("k_rope_add_0_output_qdq", "k_rope_add_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_cast_to_int8_qdq", "k_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for V-related nodes (per KV head) + copyQDQParams("v_cast_to_int16_qdq", "v_cast_to_int16_qdq_h", num_kv_heads); + copyQDQParams("v_cast_to_int8_qdq", "v_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for attention computation (per Q head) + copyQDQParams("qk_matmul_output_qdq", "qk_matmul_output_qdq_h", num_heads); + copyQDQParams("scaling_qdq", "scaling_qdq_h", num_heads); + copyQDQParams("mul_0_output_qdq", "mul_0_output_qdq_h", num_heads); + copyQDQParams("reduce_min_output_qdq", "reduce_min_output_qdq_h", num_heads); + copyQDQParams("neg_20_qdq", "neg_20_qdq_h", num_heads); + copyQDQParams("minus_0_output_qdq", "minus_0_output_qdq_h", num_heads); + copyQDQParams("where_attn_qdq", "where_attn_qdq_h", num_heads); + copyQDQParams("softmax_output_qdq", "softmax_output_qdq_h", num_heads); + copyQDQParams("attn_value_matmul_output_qdq", "attn_value_matmul_output_qdq_h", num_heads); + } +} + +} // namespace mllm::models::qwen2::sha diff --git a/examples/qwen2_qnn_aot/qnn_aot_cfg_1.5B.json b/examples/qwen2_qnn_aot/qnn_aot_cfg_1.5B.json new file mode 100644 index 00000000..3ddf11c9 --- /dev/null +++ b/examples/qwen2_qnn_aot/qnn_aot_cfg_1.5B.json @@ -0,0 +1,51 @@ +{ + "target_machine": { + "htp_arch": "V75", + "htp_chipset": "SM8650", + "htp_try_best_performance": "HtpBurst", + "htp_security_pd_session": "HtpSignedPd", + "htp_vtcm_capability_in_mb": 8 + }, + "graph_on_qnn": [ + "model" + ], + "op_on_qnn": [ + "lm_head" + ], + "split_graph": 1, + "quant_recipe": { + "llm_recipe": true, + "layers": 28, + "builtin_llm_pass": { + "model": "qwen2", + "lm_head": { + "fallback": { + "method": 
"LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "linear": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "kv_cache": { + "key": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + }, + "value": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + } + } + } + } +} diff --git a/examples/qwen2_qnn_aot/qnn_aot_cfg_3B.json b/examples/qwen2_qnn_aot/qnn_aot_cfg_3B.json new file mode 100644 index 00000000..d765567f --- /dev/null +++ b/examples/qwen2_qnn_aot/qnn_aot_cfg_3B.json @@ -0,0 +1,51 @@ +{ + "target_machine": { + "htp_arch": "V75", + "htp_chipset": "SM8650", + "htp_try_best_performance": "HtpBurst", + "htp_security_pd_session": "HtpSignedPd", + "htp_vtcm_capability_in_mb": 8 + }, + "graph_on_qnn": [ + "model" + ], + "op_on_qnn": [ + "lm_head" + ], + "split_graph": 1, + "quant_recipe": { + "llm_recipe": true, + "layers": 36, + "builtin_llm_pass": { + "model": "qwen2", + "lm_head": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "linear": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "kv_cache": { + "key": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + }, + "value": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + } + } + } + } +} diff --git a/examples/qwen2_qnn_aot/qnn_aot_cfg_7B.json b/examples/qwen2_qnn_aot/qnn_aot_cfg_7B.json new file mode 100644 index 00000000..85f66abd --- /dev/null +++ b/examples/qwen2_qnn_aot/qnn_aot_cfg_7B.json @@ -0,0 +1,51 @@ +{ + "target_machine": { + "htp_arch": "V79", + "htp_chipset": "SM8750", + "htp_try_best_performance": "HtpBurst", + "htp_security_pd_session": "HtpSignedPd", + "htp_vtcm_capability_in_mb": 8 + }, + "graph_on_qnn": [ + "model" + ], + "op_on_qnn": [ + "lm_head" + ], + "split_graph": 1, + "quant_recipe": { + "llm_recipe": true, + "layers": 32, + "builtin_llm_pass": { + "model": "qwen2", + "lm_head": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "linear": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "kv_cache": { + "key": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + }, + "value": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + } + } + } + } +} diff --git a/examples/qwen3_qnn_aot/CMakeLists.txt b/examples/qwen3_qnn_aot/CMakeLists.txt index 18041bdc..6556d055 100644 --- a/examples/qwen3_qnn_aot/CMakeLists.txt +++ b/examples/qwen3_qnn_aot/CMakeLists.txt @@ -3,8 +3,13 @@ if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) add_executable(mllm-qwen3-aot-c compile.cpp) target_link_libraries(mllm-qwen3-aot-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) target_include_directories(mllm-qwen3-aot-c PRIVATE ${MLLM_INCLUDE_DIR}) + + # SHA (Single Head Attention) version - MHA to SHA optimization + add_executable(mllm-qwen3-aot-sha-c compile_sha.cpp) + target_link_libraries(mllm-qwen3-aot-sha-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) + target_include_directories(mllm-qwen3-aot-sha-c PRIVATE ${MLLM_INCLUDE_DIR}) endif() add_executable(mllm-qwen3-aot-runner aot_run.cpp) target_link_libraries(mllm-qwen3-aot-runner PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) -target_include_directories(mllm-qwen3-aot-runner PRIVATE ${MLLM_INCLUDE_DIR}) \ No newline at end of file +target_include_directories(mllm-qwen3-aot-runner PRIVATE 
${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen3_qnn_aot/compile.cpp b/examples/qwen3_qnn_aot/compile.cpp index f47b0dee..cc813fe3 100644 --- a/examples/qwen3_qnn_aot/compile.cpp +++ b/examples/qwen3_qnn_aot/compile.cpp @@ -20,8 +20,8 @@ MLLM_MAIN({ Argparse::parse(argc, argv); - constexpr int N = 32; - constexpr int CL = 1024; + int N = 32; + int CL = 1024; if (help.isSet()) { Argparse::printHelp(); @@ -39,38 +39,98 @@ MLLM_MAIN({ auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); // Add params for causal mask { - params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); - params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); } model.load(params); - // Sequence: [B, N] - // past_key_i: [B, H, D, CL-N] for each layer i - // past_value_i: [B, H, CL-N, D] for each layer i - // causal_mask: [B, 1, N, CL] - auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); - auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. - // NOTE: force set causal mask to UInt16Asy - // NOTE: Attach scale and zero point to causal mask { - causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); - causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); - causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); - } + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } - // Create KV cache inputs for all layers - std::unordered_map trace_inputs; - trace_inputs["sequence"] = sequence; - trace_inputs["causal_mask"] = causal_mask; + auto ir = model.trace(trace_inputs, {}); - for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { - auto past_key_name = "past_key_" + std::to_string(i); - auto past_value_name = "past_value_" + std::to_string(i); + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); - // clang-format off + mllm::redirect("qwen3_qnn_aot_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. + { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off trace_inputs[past_key_name] = mllm::Tensor::empty({ 1, model_cfg.num_key_value_heads, @@ -84,20 +144,17 @@ MLLM_MAIN({ trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); - // clang-format on - } - - auto ir = model.trace(trace_inputs, {}); + // clang-format on + } - // Create Qnn AOT Model - auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", - mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + auto ir = model.trace(trace_inputs, {}); - mllm::ir::PassManager pm(ir["model"]); - pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); - pm.run(); + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); - mllm::redirect("qwen3_qnn_aot.mir", [&]() { mllm::print(ir["model"]); }); + mllm::redirect("qwen3_qnn_aot_1.mir", [&]() { mllm::print(ir["model"]); }); + } qnn_aot_env.saveContext("context.0", "qwen3-1.7B-lpbq.bin"); }); diff --git a/examples/qwen3_qnn_aot/compile_sha.cpp b/examples/qwen3_qnn_aot/compile_sha.cpp new file mode 100644 index 00000000..8e6dd232 --- /dev/null +++ b/examples/qwen3_qnn_aot/compile_sha.cpp @@ -0,0 +1,196 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. Enables better memory locality per head +// +// Usage: +// ./compile_sha -m /path/to/model.mllm -c /path/to/config.json -aot_cfg /path/to/qnn_aot_cfg.json + +#include +#include +#include +#include +#include +#include + +#include "modeling_qwen_qnn_aot_sha.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); + auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); + auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + + Argparse::parse(argc, argv); + + int N = 32; + int CL = 1024; + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + if (!qnn_aot_cfg_files.isSet()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided"); + Argparse::printHelp(); + return -1; + } + + auto model_cfg = mllm::models::qwen3::Qwen3Config(model_cfg_path.get()); + + // Load original parameters + auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); + + // ============================================================================ + // Key Step: Prepare SHA parameters by slicing MHA weights + // ============================================================================ + // This is the critical step that transforms MHA weights into SHA weights. + // For each Q/K/V projection, we slice the weight matrix into per-head pieces. 
+ // + // Original: q_proj.weight [num_heads * head_dim, hidden_size, 1, 1] + // SHA: q_proj.{h}.weight [head_dim, hidden_size, 1, 1] for each head h + // + mllm::print("Preparing SHA parameters (slicing MHA weights)..."); + mllm::models::qwen3::sha::prepareParametersForSHA(params, model_cfg); + mllm::print("SHA parameters prepared."); + + // Create SHA model + auto model = mllm::models::qwen3::sha::Qwen3ForCausalLM_SHA(model_cfg); + + // Add params for causal mask + { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + } + model.load(params); + + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. + + { + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model (seq=32)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen3_qnn_aot_sha_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. 
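+  // Presumably the 32-token graph traced above serves chunked prefill while the
+  // 1-token graph below serves autoregressive decode; both are lowered into the
+  // same QNN context ("context.0") saved at the end of this file.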
+ { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model (seq=1)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen3_qnn_aot_sha_1.mir", [&]() { mllm::print(ir["model"]); }); + } + + qnn_aot_env.saveContext("context.0", "qwen3-1.7B-lpbq-sha.bin"); + + mllm::print("SHA compilation completed successfully!"); + mllm::print("Output files:"); + mllm::print(" - qwen3_qnn_aot_sha_32.mir (IR dump for seq=32)"); + mllm::print(" - qwen3_qnn_aot_sha_1.mir (IR dump for seq=1)"); + mllm::print(" - qwen3-1.7B-lpbq-sha.bin (QNN context)"); +}); diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp new file mode 100644 index 00000000..272535d5 --- /dev/null +++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp @@ -0,0 +1,881 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// The optimization splits large Q/K/V projections into per-head projections, +// allowing QNN to optimize each head separately, reducing AOT compilation time +// and improving HTP performance. 
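+//
+// Typical usage (a sketch of the intended flow as exercised by compile_sha.cpp,
+// not a normative API description): load the MHA checkpoint, call
+// sha::prepareParametersForSHA(params, cfg) to slice the per-head weights,
+// construct Qwen3ForCausalLM_SHA(cfg), then model.load(params) before tracing.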
+ +#pragma once + +#include "mllm/core/TensorStorage.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/compile/ir/Trace.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" + +namespace mllm::models::qwen3::sha { + +namespace ptq { + +Tensor QDQ_CONSTANT(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = qdq_name_in_pytorch + ".scale"; + std::string zp_name = qdq_name_in_pytorch + ".zero_point"; + switch (in.dtype()) { + case kFloat32: + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + return in; +} + +Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + if (m->getModuleName().empty()) { + scale_name = qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point"; + } else { + scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + } + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + // For Constant! + case kFloat32: { + MLLM_RT_ASSERT_EQ(in.rank(), 1); + MLLM_RT_ASSERT_EQ(in.size(-1), 1); + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_KV(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + // The inputs is int8 sym. which means zero_point should be changed. + switch (in.dtype()) { + case kUInt8PerTensorSym: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + MLLM_RT_ASSERT_EQ(zp.item(), 0); + + // Is 128! not 127! + auto new_zp = Tensor::constant(128, kInt32).setName(zp_name).setMemType(kParamsNormal); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", new_zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." 
+ qdq_name_in_pytorch + ".fake_quant.zero_point"; + + (void)in.__unsafeSetDType(kUInt16PerTensorAsy); + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +} // namespace ptq + +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + +using vi32 = std::vector; +#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G16 + +// Using Conv2D to replace Linear. +// Conv2D Filter Weight is [1, 1, In, Out] +// Conv2D Activation is [N, H=1, W=Seq, In] + +class Qwen3MLP final : public nn::Module { + nn::Conv2D gate_proj_; + nn::Conv2D up_proj_; + nn::Conv2D down_proj_; + nn::SiLU silu_; + int hidden_size_; + int intermediate_size_; + + public: + Qwen3MLP() = default; + Qwen3MLP(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY); + hidden_size_ = cfg.hidden_size; + intermediate_size_ = cfg.intermediate_size; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = ptq::QDQ(this, x, "up_proj_input_qdq"); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true); + auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true); + + // SiLU + gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")), + "act_output_qdq"); + + auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq"); + o = o.view({1, 1, -1, intermediate_size_}, true); + o = down_proj_(o).view({1, -1, hidden_size_}, true); + + return {o}; + } +}; + +// ============================================================================ +// Single Head Attention (SHA) Implementation +// ============================================================================ +// +// This class implements SHA where each attention head has its own separate +// Conv2D projection, instead of one large MHA projection that processes all +// heads at once. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. 
Enables better memory locality per head +// + +class Qwen3AttentionSHA final : public nn::Module { + // Per-head Q projections: num_attention_heads Conv2D(hidden_size, head_dim) + std::vector q_projs_; + // Per-head K projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector k_projs_; + // Per-head V projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector v_projs_; + // Single O projection remains unchanged (concatenated heads -> hidden_size) + nn::Conv2D o_proj_; + + // Per-head RMSNorm for Q + std::vector rms_norm_q_; + // Per-head RMSNorm for K (shared across GQA groups) + std::vector rms_norm_k_; + + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + float scale_; + + public: + Qwen3AttentionSHA() = default; + + Qwen3AttentionSHA(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + scale_ = (1.f / sqrtf((float)head_dim_)); + + // Register per-head Q projections + for (int h = 0; h < num_attention_heads_; ++h) { + q_projs_.emplace_back(reg("q_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head K projections + for (int h = 0; h < num_key_value_heads_; ++h) { + k_projs_.emplace_back(reg("k_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head V projections + for (int h = 0; h < num_key_value_heads_; ++h) { + v_projs_.emplace_back(reg("v_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // O projection remains the same (combines all heads) + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, CONV2D_PROPERTY); + + // Per-head Q RMSNorm + for (int h = 0; h < num_attention_heads_; ++h) { + rms_norm_q_.emplace_back(reg("q_norm." + std::to_string(h), cfg.rms_norm_eps)); + } + + // Per-head K RMSNorm (for KV heads) + for (int h = 0; h < num_key_value_heads_; ++h) { + rms_norm_k_.emplace_back(reg("k_norm." + std::to_string(h), cfg.rms_norm_eps)); + } + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + const auto& past_key = inputs[4]; // [B, num_kv_heads, D, S] + const auto& past_value = inputs[5]; // [B, num_kv_heads, S, D] + + // [B, S, D] - shared QDQ for input to all Q/K/V projections + hidden_states = ptq::QDQ(this, hidden_states, "q_proj_input_qdq"); + hidden_states = hidden_states.view({1, 1, -1, hidden_size_}, true); + + // ======================================================================== + // Per-head Q/K/V Projections + // ======================================================================== + // This is the key SHA optimization: instead of one large projection for all + // heads, we have separate smaller projections per head. 
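+    // Rough illustration (the concrete numbers are assumptions, not taken from a
+    // specific config): with hidden_size=2048, num_attention_heads=16 and
+    // head_dim=128, one q_proj Conv2D with weight [1, 1, 2048, 2048] becomes 16
+    // Conv2Ds with weight [1, 1, 2048, 128]; K and V each use
+    // num_key_value_heads such per-head Conv2Ds.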
+ + // Compute per-head Q projections: each outputs [1, 1, S, head_dim] + std::vector query_states_per_head; + for (int h = 0; h < num_attention_heads_; ++h) { + auto q_h = q_projs_[h](hidden_states); + q_h = q_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + query_states_per_head.push_back(q_h); + } + + // Compute per-head K projections + std::vector key_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto k_h = k_projs_[h](hidden_states); + k_h = k_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + key_states_per_head.push_back(k_h); + } + + // Compute per-head V projections + std::vector value_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto v_h = v_projs_[h](hidden_states); + v_h = v_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + value_states_per_head.push_back(v_h); + } + + // ======================================================================== + // Per-head RMSNorm and RoPE + // ======================================================================== + + auto cos = llm_embedding_cos.unsqueeze(1, true); + auto sin = llm_embedding_sin.unsqueeze(1, true); + + // Apply RMSNorm and RoPE per Q head + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + query_states_per_head[h] = rms_norm_q_[h](ptq::QDQ(this, query_states_per_head[h], "q_norm_input_qdq_h" + h_str)); + query_states_per_head[h] = ptq::QDQ(this, query_states_per_head[h], "q_norm_output_qdq_h" + h_str); + + // Apply RoPE + query_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, query_states_per_head[h] * cos, "q_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(query_states_per_head[h], this, "q_rope_neg_half_qdq_h" + h_str) * sin, + "q_rope_mul_1_output_qdq_h" + h_str), + "q_rope_add_0_output_qdq_h" + h_str); + } + + // Apply RMSNorm and RoPE per K head + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + key_states_per_head[h] = rms_norm_k_[h](ptq::QDQ(this, key_states_per_head[h], "k_norm_input_qdq_h" + h_str)); + key_states_per_head[h] = ptq::QDQ(this, key_states_per_head[h], "k_norm_output_qdq_h" + h_str); + + // Apply RoPE + key_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, key_states_per_head[h] * cos, "k_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(key_states_per_head[h], this, "k_rope_neg_half_qdq_h" + h_str) * sin, + "k_rope_mul_1_output_qdq_h" + h_str), + "k_rope_add_0_output_qdq_h" + h_str); + } + + // ======================================================================== + // KV Cache Processing per head + // ======================================================================== + + std::vector new_key_per_head; + std::vector new_value_per_head; + std::vector key_cache_per_head; + std::vector value_cache_per_head; + + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + + // K: De-quantize and re-quantize to int8 + auto k_h = key_states_per_head[h].to(kFloat32); + k_h = k_h.to(kUInt8PerTensorSym); + k_h = ptq::QDQ_KV(this, k_h, "k_cast_to_int8_qdq_h" + h_str); + k_h = k_h.transpose(2, 3); // [B, 1, D, S] + + // V: Quantize to int16 then int8 + auto v_h = ptq::QDQ(this, value_states_per_head[h], "v_cast_to_int16_qdq_h" + h_str); + v_h = v_h.to(kFloat32); + v_h = v_h.to(kUInt8PerTensorSym); + v_h = ptq::QDQ_KV(this, v_h, "v_cast_to_int8_qdq_h" + h_str); + + new_key_per_head.push_back(k_h); + new_value_per_head.push_back(v_h); + + // Slice past cache for this head + auto past_k_h = past_key.slice({kAll, 
{h, h + 1}, kAll, kAll}, true); + auto past_v_h = past_value.slice({kAll, {h, h + 1}, kAll, kAll}, true); + + // Concat current with past + key_cache_per_head.push_back(nn::functional::concat({past_k_h, k_h}, -1)); + value_cache_per_head.push_back(nn::functional::concat({past_v_h, v_h}, 2)); + } + + // ======================================================================== + // Per-head Attention Computation + // ======================================================================== + // Each Q head computes attention with its corresponding KV head (GQA support) + + std::vector attn_outputs; + + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + int kv_head_idx = h / num_key_value_groups_; + + const auto& q_h = query_states_per_head[h]; + const auto& kh = key_cache_per_head[kv_head_idx]; + const auto& vh = value_cache_per_head[kv_head_idx]; + + // QK^T + auto attn = ptq::QDQ(this, nn::functional::matmul(q_h, kh), "qk_matmul_output_qdq_h" + h_str); + + // Scale + auto scale = Tensor::constant(scale_, kFloat32); + scale = ptq::QDQ(this, scale, "scaling_qdq_h" + h_str); + attn = ptq::QDQ(this, attn.mulConstant(scale), "mul_0_output_qdq_h" + h_str); + + // Masked Softmax + auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq_h" + h_str); + auto minus_value = Tensor::constant(-20, kFloat32); + minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq_h" + h_str); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq_h" + h_str); + auto zero_constant = Tensor::constant(0.f, kFloat32); + zero_constant = ptq::QDQ_CONSTANT(this, zero_constant, "constant_zero"); + attn = nn::functional::where(causal_mask.equalConstant(zero_constant), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq_h" + h_str); + attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq_h" + h_str); + + // Output: attn @ V + auto y_h = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq_h" + h_str); + attn_outputs.push_back(y_h); + } + + // ======================================================================== + // Concatenate and Output Projection + // ======================================================================== + + // Concat all head outputs: [B, num_heads, S, D] + auto y = nn::functional::concat(attn_outputs, 1); + + // Reshape and apply O projection + y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); + y = o_proj_(y).view({1, -1, hidden_size_}, true); + + // Concat new keys and values back to original format + auto new_key = nn::functional::concat(new_key_per_head, 1); + auto new_value = nn::functional::concat(new_value_per_head, 1); + + return {y, new_key, new_value}; + } + + int layer_idx_; +}; + +class Qwen3DecoderSHA final : public nn::Module { + public: + int layer_idx_; + Qwen3AttentionSHA self_attn_; + Qwen3MLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen3DecoderSHA() = default; + + Qwen3DecoderSHA(const std::string& name, const Qwen3Config& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; 
+ auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + auto hidden_states = inputs[0]; + if (layer_idx_ != 0) { hidden_states = ptq::QDQ(this, hidden_states, "input_layernorm_input_qdq"); } + auto residual = hidden_states; + hidden_states = input_layer_norm_(hidden_states); + auto _ = self_attn_(hidden_states, llm_embedding_sin, llm_embedding_cos, causal_mask, past_key, past_value); + hidden_states = _[0]; + hidden_states = ptq::QDQ(this, residual + ptq::QDQ(this, hidden_states, "add_0_lhs_input_qdq"), "add_0_output_qdq"); + residual = hidden_states; + hidden_states = post_attention_layer_norm_(hidden_states); + hidden_states = mlp_(hidden_states)[0]; + hidden_states = residual + ptq::QDQ(this, hidden_states, "add_1_lhs_input_qdq"); + return {hidden_states, _[1], _[2]}; + } +}; + +class Qwen3TextSHA final : public nn::Module { + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + nn::Param rope_sin_; + nn::Param rope_cos_; + int32_t num_hidden_layers_; + int32_t hidden_size_; + + public: + Qwen3TextSHA() = default; + + Qwen3TextSHA(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + num_hidden_layers_ = cfg.num_hidden_layers; + hidden_size_ = cfg.hidden_size; + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + rope_sin_ = reg("mllm_max_sin_embedding", "model.mllm_max_sin_embedding"); + rope_cos_ = reg("mllm_max_cos_embedding", "model.mllm_max_cos_embedding"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + const auto& position_ids = inputs[1]; + auto causal_mask = inputs[2]; + + // clang-format off + auto llm_embedding_sin = nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids); + auto llm_embedding_cos = nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq"), 1, position_ids); + // clang-format on + + std::vector keys; + std::vector values; + for (auto [index, block] : enumerate(blocks)) { + auto pk = inputs[3 + index]; + auto pv = inputs[3 + index + num_hidden_layers_]; + auto _ = block(x, llm_embedding_sin, llm_embedding_cos, causal_mask, pk, pv); + x = _[0]; + keys.push_back(_[1]); + values.push_back(_[2]); + } + + x = norm_(ptq::QDQ(this, x, "norm_input_qdq")); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto ret = std::vector{x}; + for (const auto& item : keys) { ret.push_back(item); } + for (const auto& item : values) { ret.push_back(item); } + + return ret; + } +}; + +class Qwen3ForCausalLM_SHA : public ARGeneration, public nn::Module { + public: + explicit Qwen3ForCausalLM_SHA(const Qwen3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY); + } + } + + IROutput trace(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Things we need to return + ir::IRContext::ptr_t llm_ir = nullptr; + + 
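+    // Expected entries in `input`: "sequence", "causal_mask", optionally
+    // "position_ids", plus per-layer "past_key_i" / "past_value_i" tensors for
+    // i in [0, num_hidden_layers); missing KV entries raise a runtime error below.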
auto sequence = input.at("sequence"); + auto causal_mask = input.at("causal_mask"); + + std::vector kv_caches; + + // Append Key + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + if (input.count(past_key_name)) { + kv_caches.push_back(input.at(past_key_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Append Value + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_value_name = "past_value_" + std::to_string(i); + if (input.count(past_value_name)) { + kv_caches.push_back(input.at(past_value_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({1}, kInt32, kCPU).alloc(); + *position_ids.offsettedPtr({0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({seq_len}, kInt32, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[s] = s; } + } + + ir::lowlevel::traceStart(); + + // Build inputs for llm: sequence, llm_embedding_sin, llm_embedding_cos, causal_mask, then all KV caches + std::vector llm_inputs = {sequence, position_ids, causal_mask}; + llm_inputs.insert(llm_inputs.end(), kv_caches.begin(), kv_caches.end()); + + sequence = llm(llm_inputs)[0]; + sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq")); + sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq"); + ir::lowlevel::traceComment(" ╔═════╗ "); + ir::lowlevel::traceComment(" ║ o o ║ "); + ir::lowlevel::traceComment(" ║ ▽ ║ "); + ir::lowlevel::traceComment(" ╚═════╝ "); + ir::lowlevel::traceComment(" ║ ║ "); + ir::lowlevel::traceComment(" ╱╩╦╦╩╲ "); + llm_ir = ir::lowlevel::traceStop(); + + return {{"model", llm_ir}}; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + private: + const Qwen3Config& cfg; + Qwen3TextSHA llm; + nn::Conv2D lm_head_; + bool tie_word_embeddings_; +}; + +// ============================================================================ +// Weight Slicing Utilities for SHA +// ============================================================================ +// +// These functions are used during the compile phase to slice the original +// MHA weights into per-head SHA weights. +// + +/** + * @brief Prepares the parameter file by slicing MHA weights into SHA weights. + * + * This function takes the original parameter file with MHA weights and creates + * new per-head weights for the SHA model. 
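+ *
+ * For example (hypothetical shapes: hidden_size=1024, head_dim=128,
+ * num_attention_heads=16, num_key_value_heads=8), the fused q_proj weight is cut
+ * along its output-channel dimension into 16 slices of 128 channels each, and
+ * k_proj / v_proj into 8 slices each; the LPBQ scales are split the same way.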
+ * + * Original weight layout for Conv2D: [out_channels, in_channels, 1, 1] + * - q_proj.weight: [num_heads * head_dim, hidden_size, 1, 1] + * - k_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * - v_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * + * For LPBQ quantization, also need to slice: + * - scale1: flattened scale for block quantization + * - scale2: flattened scale for block quantization + * + * SHA weight layout: + * - q_proj.{h}.weight: [head_dim, hidden_size, 1, 1] for h in [0, num_heads) + * - q_proj.{h}.scale1: sliced scale for head h + * - q_proj.{h}.scale2: sliced scale for head h + * - Similar for k_proj and v_proj + */ +inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Qwen3Config& cfg) { + int num_heads = cfg.num_attention_heads; + int num_kv_heads = cfg.num_key_value_heads; + int head_dim = cfg.head_dim; + int num_layers = cfg.num_hidden_layers; + + // Helper lambda to slice and push Conv2D params (weight, scale1, scale2) + // For LPBQ, scale1 and scale2 are flattened along the output channel dimension + // Scale size per head = total_scale_size / num_heads_for_this_proj + auto sliceAndPushConv2DParams = [&](const std::string& orig_name_prefix, const std::string& new_name_prefix, + int total_out_channels, int out_channels_per_head, int num_splits) { + // Process weight: HWIO format [H=1, W=1, In_channels, Out_channels] + // For q_proj: [1, 1, hidden_size, num_heads * head_dim] + // Slice on the last dimension (Out_channels) + std::string orig_weight_name = orig_name_prefix + ".weight"; + if (params->has(orig_weight_name)) { + auto orig_weight = params->pull(orig_weight_name); + + for (int h = 0; h < num_splits; ++h) { + std::string new_weight_name = new_name_prefix + "." + std::to_string(h) + ".weight"; + int start_idx = h * out_channels_per_head; + int end_idx = (h + 1) * out_channels_per_head; + // HWIO format: slice on dim 3 (Out_channels) + auto sliced = orig_weight.slice({kAll, kAll, kAll, {start_idx, end_idx}}, false); + params->push(new_weight_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_weight_name)); + } + } + + // Process scale1: flattened, size = total_out_channels / block_size (or similar) + // Slice index: (total_scale_size / num_splits) * h + std::string orig_scale1_name = orig_name_prefix + ".scale1"; + if (params->has(orig_scale1_name)) { + auto orig_scale1 = params->pull(orig_scale1_name); + int total_scale_size = orig_scale1.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale1_name = new_name_prefix + "." + std::to_string(h) + ".scale1"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale1.slice({{start_idx, end_idx}}, false); + params->push(new_scale1_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale1_name)); + } + } + + // Process scale2: flattened, same logic as scale1 + std::string orig_scale2_name = orig_name_prefix + ".scale2"; + if (params->has(orig_scale2_name)) { + auto orig_scale2 = params->pull(orig_scale2_name); + int total_scale_size = orig_scale2.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale2_name = new_name_prefix + "." 
+ std::to_string(h) + ".scale2"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale2.slice({{start_idx, end_idx}}, false); + params->push(new_scale2_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale2_name)); + } + } + }; + + for (int layer = 0; layer < num_layers; ++layer) { + std::string layer_prefix = "model.layers." + std::to_string(layer) + ".self_attn."; + + // Process Q projection: split into num_heads parts + sliceAndPushConv2DParams(layer_prefix + "q_proj", layer_prefix + "q_proj", num_heads * head_dim, head_dim, num_heads); + + // Process K projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "k_proj", layer_prefix + "k_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // Process V projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "v_proj", layer_prefix + "v_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // Process Q norm params (per head) + // RMSNorm has: weight (needs slicing), scale (scalar, copy), zero_point (scalar, copy) + { + std::string orig_weight_name = layer_prefix + "q_norm.weight"; + std::string orig_scale_name = layer_prefix + "q_norm.scale"; + std::string orig_zp_name = layer_prefix + "q_norm.zero_point"; + + if (params->has(orig_weight_name)) { + auto orig_weight = params->pull(orig_weight_name); // [num_heads * head_dim] + + // scale and zp are scalars, just need to copy + Tensor orig_scale, orig_zp; + bool has_scale = params->has(orig_scale_name); + bool has_zp = params->has(orig_zp_name); + if (has_scale) orig_scale = params->pull(orig_scale_name); + if (has_zp) orig_zp = params->pull(orig_zp_name); + + for (int h = 0; h < num_heads; ++h) { + std::string h_str = std::to_string(h); + // Weight: slice per head + std::string new_weight_name = layer_prefix + "q_norm." + h_str + ".weight"; // NOLINT + auto sliced = orig_weight.slice({{h * head_dim, (h + 1) * head_dim}}, false); + params->push(new_weight_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_weight_name)); + + // Scale: copy (scalar) + if (has_scale) { + std::string new_scale_name = layer_prefix + "q_norm." + h_str + ".scale"; // NOLINT + params->push(new_scale_name, orig_scale.contiguous().setMemType(kParamsNormal).setName(new_scale_name)); + } + + // Zero point: copy (scalar) + if (has_zp) { + std::string new_zp_name = layer_prefix + "q_norm." + h_str + ".zero_point"; // NOLINT + params->push(new_zp_name, orig_zp.contiguous().setMemType(kParamsNormal).setName(new_zp_name)); + } + } + } + } + + // Process K norm params (per KV head) + { + std::string orig_weight_name = layer_prefix + "k_norm.weight"; + std::string orig_scale_name = layer_prefix + "k_norm.scale"; + std::string orig_zp_name = layer_prefix + "k_norm.zero_point"; + + if (params->has(orig_weight_name)) { + auto orig_weight = params->pull(orig_weight_name); + + Tensor orig_scale, orig_zp; + bool has_scale = params->has(orig_scale_name); + bool has_zp = params->has(orig_zp_name); + if (has_scale) orig_scale = params->pull(orig_scale_name); + if (has_zp) orig_zp = params->pull(orig_zp_name); + + for (int h = 0; h < num_kv_heads; ++h) { + std::string h_str = std::to_string(h); + // Weight: slice per head + std::string new_weight_name = layer_prefix + "k_norm." 
+ h_str + ".weight"; // NOLINT + auto sliced = orig_weight.slice({{h * head_dim, (h + 1) * head_dim}}, false); + params->push(new_weight_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_weight_name)); + + // Scale: copy (scalar) + if (has_scale) { + std::string new_scale_name = layer_prefix + "k_norm." + h_str + ".scale"; // NOLINT + params->push(new_scale_name, orig_scale.contiguous().setMemType(kParamsNormal).setName(new_scale_name)); + } + + // Zero point: copy (scalar) + if (has_zp) { + std::string new_zp_name = layer_prefix + "k_norm." + h_str + ".zero_point"; // NOLINT + params->push(new_zp_name, orig_zp.contiguous().setMemType(kParamsNormal).setName(new_zp_name)); + } + } + } + } + + // ======================================================================== + // Duplicate QDQ parameters for each head + // ======================================================================== + // The original MHA uses shared QDQ params for all heads. For SHA, we + // duplicate these to per-head versions using "_h{N}" suffix naming. + // This allows each head to have its own quantization parameters. + + auto copyQDQParams = [&](const std::string& base_name, const std::string& new_base_name, int count) { + std::string scale_name = layer_prefix + base_name + ".fake_quant.scale"; + std::string zp_name = layer_prefix + base_name + ".fake_quant.zero_point"; + + if (params->has(scale_name)) { + auto scale = params->pull(scale_name); + auto zp = params->pull(zp_name); + + for (int h = 0; h < count; ++h) { + std::string new_scale_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.scale"; + std::string new_zp_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.zero_point"; + // QDQ scale/zp are typically scalar or small tensors, clone to ensure contiguous + params->push(new_scale_name, scale.contiguous().setMemType(kParamsNormal).setName(new_scale_name)); + params->push(new_zp_name, zp.contiguous().setMemType(kParamsNormal).setName(new_zp_name)); + } + } + }; + + // Copy QDQ params for Q-related nodes (per Q head) + copyQDQParams("q_norm_input_qdq", "q_norm_input_qdq_h", num_heads); + copyQDQParams("q_norm_output_qdq", "q_norm_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_0_output_qdq", "q_rope_mul_0_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_1_output_qdq", "q_rope_mul_1_output_qdq_h", num_heads); + copyQDQParams("q_rope_neg_half_qdq", "q_rope_neg_half_qdq_h", num_heads); + copyQDQParams("q_rope_add_0_output_qdq", "q_rope_add_0_output_qdq_h", num_heads); + + // Copy QDQ params for K-related nodes (per KV head) + copyQDQParams("k_norm_input_qdq", "k_norm_input_qdq_h", num_kv_heads); + copyQDQParams("k_norm_output_qdq", "k_norm_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_0_output_qdq", "k_rope_mul_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_1_output_qdq", "k_rope_mul_1_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_neg_half_qdq", "k_rope_neg_half_qdq_h", num_kv_heads); + copyQDQParams("k_rope_add_0_output_qdq", "k_rope_add_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_cast_to_int8_qdq", "k_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for V-related nodes (per KV head) + copyQDQParams("v_cast_to_int16_qdq", "v_cast_to_int16_qdq_h", num_kv_heads); + copyQDQParams("v_cast_to_int8_qdq", "v_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for attention computation (per Q head) + copyQDQParams("qk_matmul_output_qdq", "qk_matmul_output_qdq_h", num_heads); + 
copyQDQParams("scaling_qdq", "scaling_qdq_h", num_heads); + copyQDQParams("mul_0_output_qdq", "mul_0_output_qdq_h", num_heads); + copyQDQParams("reduce_min_output_qdq", "reduce_min_output_qdq_h", num_heads); + copyQDQParams("neg_20_qdq", "neg_20_qdq_h", num_heads); + copyQDQParams("minus_0_output_qdq", "minus_0_output_qdq_h", num_heads); + copyQDQParams("where_attn_qdq", "where_attn_qdq_h", num_heads); + copyQDQParams("softmax_output_qdq", "softmax_output_qdq_h", num_heads); + copyQDQParams("attn_value_matmul_output_qdq", "attn_value_matmul_output_qdq_h", num_heads); + } +} + +} // namespace mllm::models::qwen3::sha diff --git a/mllm/backends/qnn/QNNBackend.cpp b/mllm/backends/qnn/QNNBackend.cpp index 3fdd5dbd..1a891cb6 100644 --- a/mllm/backends/qnn/QNNBackend.cpp +++ b/mllm/backends/qnn/QNNBackend.cpp @@ -29,7 +29,7 @@ QNNBackend::QNNBackend() : Backend(kQNN, createQNNAllocator()) { QNNViewOpFactory, QNNRMSNormOpFactory, QNNTransposeOpFactory, QNNX2XOpFactory, QNNCastTypeOpFactory, QNNParamOpFactory, QNNSiLUOpFactory, QNNEmbeddingOpFactory>(); - QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_VERBOSE; // default QNN log level + QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_ERROR; // default QNN log level profilingLevel_ = ProfilingLevel::OFF; debug_ = false; // when set true, NATIVE tensor will be regared as APP_READ tensor @@ -98,8 +98,8 @@ QNNPerf::QNNPerf(const QNN_INTERFACE_VER_TYPE* qnnInterface) { .powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE, .setSleepLatency = 1, // True to consider Latency parameter otherwise False .sleepLatency = 40, // set dsp sleep latency ranges 10-65535 micro sec, refer hexagon sdk - .setSleepDisable = 1, // True to consider sleep disable/enable parameter otherwise False - .sleepDisable = 1, // True to disable sleep, False to re-enable sleep + .setSleepDisable = 0, // True to consider sleep disable/enable parameter otherwise False + .sleepDisable = 0, // True to disable sleep, False to re-enable sleep .setBusParams = 1, // True to consider Bus parameter otherwise False .busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER, .busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER, diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index e7478105..0f67bab5 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -182,6 +182,9 @@ void QnnAOTNodeTensor::setupComplexTensorQuantization(const ir::tensor::TensorVa std::vector scale_offsets(num_scale_offsets); MLLM_RT_ASSERT_EQ(num_scale_offsets, cfg->scale_level_1_fp.size(-1)); MLLM_RT_ASSERT_EQ(cfg->scale_level_0_int.dtype(), kUInt8); + MLLM_RT_ASSERT_EQ(cfg->scale_level_1_fp.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(cfg->scale_level_0_int.rank(), 1); + MLLM_RT_ASSERT_EQ(cfg->scale_level_1_fp.rank(), 1); for (int i = 0; i < num_scale_offsets; ++i) { scale_offsets[i].scale = cfg->scale_level_1_fp.at({i}); scale_offsets[i].offset = 0; @@ -270,10 +273,11 @@ QnnAOTGraph::QnnAOTGraph(QNN_INTERFACE_VER_TYPE& qnnInterface, Qnn_BackendHandle // Short Depth Conv On HMX Off QnnHtpGraph_CustomConfig_t* p_custom_config = nullptr; - p_custom_config = (QnnHtpGraph_CustomConfig_t*)malloc(sizeof(QnnHtpGraph_CustomConfig_t)); - p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF; - p_custom_config->shortDepthConvOnHmxOff = true; - htp_graph_configs.push_back(static_cast(p_custom_config)); + // FIXME: @chenghuaWang The code below will make llm inference slow!!! 
+ // p_custom_config = (QnnHtpGraph_CustomConfig_t*)malloc(sizeof(QnnHtpGraph_CustomConfig_t)); + // p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF; + // p_custom_config->shortDepthConvOnHmxOff = true; + // htp_graph_configs.push_back(static_cast(p_custom_config)); // Fold Relu Activation Into Conv Off p_custom_config = (QnnHtpGraph_CustomConfig_t*)malloc(sizeof(QnnHtpGraph_CustomConfig_t)); diff --git a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp index 444abf57..7e2a6322 100644 --- a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp +++ b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp @@ -645,34 +645,37 @@ bool LLMQuantRecipeConcatPattern::isMatch(const mllm::ir::op_ptr_t& op) { } bool LLMQuantRecipeConcatPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& node) { - // Current support concat two Tensor. Inherent first tensor's Quant Spec. + // Support concat with multiple inputs. Inherit first tensor's Quant Spec. auto concat_ir = node->cast_(); - auto i_0 = *(node->inputs().begin()); // t1 - auto i_1 = *(std::next(node->inputs().begin())); // t2 - auto o_0 = *(node->outputs().begin()); // to1 + auto o_0 = *(node->outputs().begin()); // to1 - if (concat_ir->inputs().size() != 2) { - MLLM_WARN("Current support concat two Tensor. Inherent first tensor's setting."); + if (concat_ir->inputs().empty()) { + MLLM_WARN("Concat op has no inputs."); return false; } - // Create quant_recipe if not present - if (!i_0->getAttr("quant_recipe")) { - auto i_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_0->cast_()); - i_0->setAttr("quant_recipe", i_0_spec); - } - if (!i_1->getAttr("quant_recipe")) { - auto i_1_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_1->cast_()); - i_1->setAttr("quant_recipe", i_1_spec); + auto i_0 = *(node->inputs().begin()); // First input + + // Create quant_recipe for all inputs if not present + for (auto input : node->inputs()) { + if (!input->getAttr("quant_recipe")) { + auto input_spec = genSimpleQuantizationSpecAttr(writer.getContext(), input->cast_()); + input->setAttr("quant_recipe", input_spec); + } } + // Output inherits first tensor's quant_recipe o_0->setAttr("quant_recipe", i_0->getAttr("quant_recipe")); auto annotation_attr = writer.create(); - annotation_attr->annotation_.inputs.emplace_back( - i_0->getAttr("quant_recipe")->cast_()->spec_); - annotation_attr->annotation_.inputs.emplace_back( - i_1->getAttr("quant_recipe")->cast_()->spec_); + + // Add quant_recipe for all inputs + for (auto input : node->inputs()) { + annotation_attr->annotation_.inputs.emplace_back( + input->getAttr("quant_recipe")->cast_()->spec_); + } + + // Add quant_recipe for output annotation_attr->annotation_.outputs.emplace_back( o_0->getAttr("quant_recipe")->cast_()->spec_); diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index b1d46f90..43e591fe 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -47,7 +47,7 @@ void solveLinearWeight(const ir::IRContext::ptr_t& ctx, const ParameterFile::ptr auto weight = pf->pull(mllm_op->getName() + ".weight"); // FIXME weight maybe error, Check qnn eats int8 or uint8. Here weight using int8 to store int4. 
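+  // The exporter now masks int4 weights to the low nibble (x & 0x0F), so values
+  // arrive unsigned in [0, 15] rather than signed in [-8, 7].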
- checkTypeLimits(weight, -8, 7); // Int4 + checkTypeLimits(weight, 0, 15); // Int4 checkTypeLimits(scale1, 0, 16); // UInt4 this_spec->scale_level_0_int = scale1; diff --git a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp index bee78b1b..f9eae715 100644 --- a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp +++ b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp @@ -126,6 +126,8 @@ int64_t PromptProcessor::prefill(const std::vector& prompt_tokens, i module_->setOutputTensors(output_tensors_); + MLLM_INFO("num_tokens: {}", num_tokens); + while (processed_tokens < num_tokens) { int64_t chunk_size = std::min((int64_t)config_.ar_len, num_tokens - processed_tokens); diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp index 3bbd077d..ae1fafa2 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp @@ -3,6 +3,7 @@ #include #include +#include #include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp" #include "mllm/core/DataTypes.hpp" @@ -10,6 +11,8 @@ #include "mllm/preprocessor/tokenizers/Unicode.hpp" #include "mllm/utils/Common.hpp" #include "mllm/utils/Log.hpp" +#include +#include namespace mllm::qnn::aot { Runner::Runner(const RunnerConfig& config, mllm::preprocessor::AutoTokenizer* tokenizer) @@ -56,16 +59,25 @@ bool Runner::load() { } void Runner::generate(const Tensor& prompt_tokens, int32_t seq_len, - const std::function& token_callback) { + const std::function& token_callback, bool perf) { MLLM_RT_ASSERT(prompt_tokens.rank() == 2 && prompt_tokens.dtype() == kInt64); int64_t start_pos = 0; std::vector prompt_tokens_i64; prompt_tokens_i64.reserve(prompt_tokens.shape()[1]); + for (int i = 0; i < prompt_tokens.shape()[1]; i++) { prompt_tokens_i64.push_back(prompt_tokens.ptr()[i]); } + // Measure prefill time + std::chrono::high_resolution_clock::time_point prefill_start, prefill_end; + if (perf) { prefill_start = std::chrono::high_resolution_clock::now(); } + + int64_t prefill_token_count = prompt_tokens_i64.size(); int64_t next_token = prompt_processor_->prefill(prompt_tokens_i64, start_pos); + prompt_tokens_i64.push_back(next_token); + + if (perf) { prefill_end = std::chrono::high_resolution_clock::now(); } if (token_callback) { std::wstring wstr = tokenizer_->detokenize(next_token); @@ -73,9 +85,47 @@ void Runner::generate(const Tensor& prompt_tokens, int32_t seq_len, token_callback(str); } - // int64_t cur_pos = prompt_tokens.size(-1); + int64_t cur_pos = prompt_tokens.size(-1); + + // Measure decode time + std::chrono::high_resolution_clock::time_point decode_start, decode_end; + if (perf) { decode_start = std::chrono::high_resolution_clock::now(); } + + int64_t generated_count = token_generator_->generate(prompt_tokens_i64, cur_pos, seq_len, token_callback, false); + + if (perf) { + decode_end = std::chrono::high_resolution_clock::now(); + + // Calculate durations in microseconds + auto prefill_duration = std::chrono::duration_cast(prefill_end - prefill_start).count(); + auto decode_duration = std::chrono::duration_cast(decode_end - decode_start).count(); - // token_generator_->generate(prompt_tokens, cur_pos, seq_len, token_callback, false); + // Calculate TPS + double prefill_tps = 0.0; + double decode_tps = 0.0; + + if (prefill_duration > 0) { prefill_tps = (double)prefill_token_count / (prefill_duration / 1000000.0); } + + if (decode_duration > 0 && generated_count > 0) { decode_tps = (double)generated_count / (decode_duration / 1000000.0); } + + // 
Print performance summary + fmt::print(fg(fmt::color::cyan), "\n{:=^50}\n", " Performance Summary "); + fmt::print(fg(fmt::color::white), "{:<20}: ", "Prefill time"); + fmt::print(fg(fmt::color::yellow), "{:>10.2f} μs", (double)prefill_duration); + if (prefill_tps > 0) { fmt::print(fg(fmt::color::white), " ({:>6.2f} tokens/s)", prefill_tps); } + fmt::print("\n"); + + fmt::print(fg(fmt::color::white), "{:<20}: ", "Decode time"); + fmt::print(fg(fmt::color::yellow), "{:>10.2f} μs", (double)decode_duration); + if (decode_tps > 0) { fmt::print(fg(fmt::color::white), " ({:>6.2f} tokens/s)", decode_tps); } + fmt::print("\n"); + + fmt::print(fg(fmt::color::white), "{:<20}: ", "Prefill tokens"); + fmt::print(fg(fmt::color::green), "{:>10}\n", prefill_token_count); + + fmt::print(fg(fmt::color::white), "{:<20}: ", "Decode tokens"); + fmt::print(fg(fmt::color::green), "{:>10}\n", generated_count); + } } } // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp index 12a13e67..51ce86c7 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp @@ -22,7 +22,8 @@ class Runner { ~Runner() = default; bool load(); - void generate(const Tensor& prompt_tokens, int32_t seq_len, const std::function& token_callback); + void generate(const Tensor& prompt_tokens, int32_t seq_len, const std::function& token_callback, + bool perf = false); private: RunnerConfig config_; diff --git a/mllm/backends/qnn/aot_rt/TokenGenerator.cpp b/mllm/backends/qnn/aot_rt/TokenGenerator.cpp index d2cbbf43..4e088435 100644 --- a/mllm/backends/qnn/aot_rt/TokenGenerator.cpp +++ b/mllm/backends/qnn/aot_rt/TokenGenerator.cpp @@ -98,10 +98,10 @@ void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { } template -int64_t TokenGenerator::generate(std::vector& tokens, int64_t start_pos, int32_t seq_len, +int64_t TokenGenerator::generate(std::vector& tokens, int64_t start_pos, int32_t seq_len, const std::function& token_callback, bool dump_logits) { int64_t current_pos = start_pos; - uint64_t next_token = tokens.back(); + int64_t next_token = tokens.back(); int64_t generated_count = 0; // Ensure KV cache is arranged for decode (1 token) diff --git a/mllm/backends/qnn/aot_rt/TokenGenerator.hpp b/mllm/backends/qnn/aot_rt/TokenGenerator.hpp index da50836d..b40b9725 100644 --- a/mllm/backends/qnn/aot_rt/TokenGenerator.hpp +++ b/mllm/backends/qnn/aot_rt/TokenGenerator.hpp @@ -25,7 +25,7 @@ class TokenGenerator { virtual const std::vector& get_all_logits(); - virtual int64_t generate(std::vector& tokens, int64_t start_pos, int32_t seq_len, + virtual int64_t generate(std::vector& tokens, int64_t start_pos, int32_t seq_len, const std::function& token_callback, bool dump_logits); protected: diff --git a/pymllm/backends/qualcomm/transformers/core/qlinear.py b/pymllm/backends/qualcomm/transformers/core/qlinear.py index d3bc8150..9e90ba8a 100644 --- a/pymllm/backends/qualcomm/transformers/core/qlinear.py +++ b/pymllm/backends/qualcomm/transformers/core/qlinear.py @@ -242,6 +242,13 @@ def convert_to_conv2d_deploy_hwio(self): .contiguous() ) + # Packing for Qnn to use + # Qnn's packing will not do x & 0x0F, so we need to do it here. 
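+        # Without this mask, negative int4 values keep their sign-extended high
+        # bits and fall outside the unsigned [0, 15] nibble range checked by the
+        # AOT PTQ pass.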
+ mask = torch.full( + weight_int4.size(), 0x0F, dtype=torch.int8, device=weight_int4.device + ) + weight_int4 = torch.bitwise_and(mask, weight_int4) + del self.weight self.register_buffer("weight", weight_int4) self.register_buffer("scale1", quantized_scales.flatten()) diff --git a/pymllm/backends/qualcomm/transformers/llama/modeling_llama.py b/pymllm/backends/qualcomm/transformers/llama/modeling_llama.py new file mode 100644 index 00000000..119ec04b --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/llama/modeling_llama.py @@ -0,0 +1,831 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import create_causal_mask +from transformers.modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import check_model_inputs +from transformers.models.llama.configuration_llama import LlamaConfig + +# Replace linear, rms_norm with: +from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.backends.qualcomm.transformers.core.qlinear import ( + QLinearLPBQ, +) +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver + + +logger = logging.get_logger(__name__) + + +class LlamaRMSNorm(QRMSNorm): + def __init__(self, hidden_size, eps=1e-6, quant_bits=16): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__(hidden_size, eps=eps, quant_bits=quant_bits) + + +class LlamaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: 
LlamaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half( + x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver +): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1)) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
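+
+    Note:
+        In this port the attention module applies RoPE inline with per-node QDQ
+        observers and calls `rotate_half` with explicit observer arguments; this
+        helper mirrors the upstream Hugging Face signature and is kept mainly for
+        reference.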
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = QLinearLPBQ( + self.hidden_size, + self.intermediate_size, + bias=config.mlp_bias, + block_size=32, + ) + self.up_proj = QLinearLPBQ( + self.hidden_size, + self.intermediate_size, + bias=config.mlp_bias, + block_size=32, + ) + self.down_proj = QLinearLPBQ( + self.intermediate_size, + self.hidden_size, + bias=config.mlp_bias, + block_size=32, + ) + self.act_fn = ACT2FN[config.hidden_act] + + # QDQ + self.up_proj_input_qdq = ActivationQDQ(bits=16) + self.up_proj_output_qdq = ActivationQDQ(bits=16) + self.gate_proj_output_qdq = ActivationQDQ(bits=16) + self.act_output_qdq = ActivationQDQ(bits=16) + self.down_proj_input_qdq = ActivationQDQ(bits=16) + # For sigmoid output: scale = 1 / (q_max - q_min + 1), zp = 0 + # For 16-bit: q_min = 0, q_max = 65535 + sigmoid_scale = 1.0 / (65535 - 0 + 1) # 1 / 65536 + self.sigmoid_output_qdq = FixedActivationQDQ( + scale=sigmoid_scale, zero_point=0, bits=16 + ) + + def forward(self, x): + x = self.up_proj_input_qdq(x) + up_result = self.up_proj_output_qdq(self.up_proj(x)) + gate_result = self.gate_proj_output_qdq(self.gate_proj(x)) + + # SiLU or other activation + gate_result = self.act_output_qdq( + gate_result * self.sigmoid_output_qdq(F.sigmoid(gate_result)) + ) + + o = self.down_proj_input_qdq(gate_result * up_result) + o = self.down_proj(o) + return o + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = QLinearLPBQ( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + block_size=32, + ) + self.k_proj = QLinearLPBQ( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + block_size=32, + ) + self.v_proj = QLinearLPBQ( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + block_size=32, + ) + self.o_proj = QLinearLPBQ( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + block_size=32, + ) + + # QDQ + self.q_proj_input_qdq = ActivationQDQ(bits=16) + self.k_proj_input_qdq = ActivationQDQ(bits=16) + self.v_proj_input_qdq = ActivationQDQ(bits=16) + + self.q_proj_output_qdq = ActivationQDQ(bits=16) + self.k_proj_output_qdq = ActivationQDQ(bits=16) + + self.q_rope_mul_0_output_qdq = ActivationQDQ(bits=16) + self.q_rope_mul_1_output_qdq = ActivationQDQ(bits=16) + self.q_rope_add_0_output_qdq = ActivationQDQ(bits=16) + self.k_rope_mul_0_output_qdq = ActivationQDQ(bits=16) + self.k_rope_mul_1_output_qdq = ActivationQDQ(bits=16) + self.k_rope_add_0_output_qdq = ActivationQDQ(bits=16) + + self.q_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.q_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + 
eps=0.0001 / 65535, + is_dynamic=False, + ) + self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer.add_observer( + self.k_proj_input_qdq.fake_quant.activation_post_process + ) + self.k_rope_concat_observer.add_observer( + self.k_rope_neg_half_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_proj_input_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_rope_neg_half_qdq.fake_quant.activation_post_process + ) + + # In qnn, is uint8 sym. + self.k_cast_to_int8_qdq = ActivationQDQ( + bits=8, qscheme=torch.per_tensor_symmetric + ) + self.v_cast_to_int8_qdq = ActivationQDQ( + bits=8, qscheme=torch.per_tensor_symmetric + ) + + self.v_cast_to_int16_qdq = ActivationQDQ(bits=16) + self.qk_matmul_output_qdq = ActivationQDQ(bits=16) + self.scaling_qdq = ActivationQDQ(bits=16) + self.neg_20_qdq = ActivationQDQ(bits=16) + self.reduce_min_output_qdq = ActivationQDQ(bits=16) + self.mul_0_output_qdq = ActivationQDQ(bits=16) + self.minus_0_output_qdq = ActivationQDQ(bits=16) + self.softmax_output_qdq = ActivationQDQ(bits=16) + self.attn_value_matmul_output_qdq = ActivationQDQ(bits=16) + self.where_attn_qdq = ActivationQDQ(bits=16) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + hidden_states = self.q_proj_input_qdq(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states = self.q_proj_output_qdq(query_states) + + hidden_states_k = self.k_proj_input_qdq(hidden_states) + key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj_output_qdq(key_states) + + hidden_states_v = self.v_proj_input_qdq(hidden_states) + value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + query_states = self.q_rope_add_0_output_qdq( + self.q_rope_mul_0_output_qdq(query_states * cos) + + self.q_rope_mul_1_output_qdq( + rotate_half( + query_states, + self.q_proj_input_qdq.fake_quant.activation_post_process, + self.q_rope_neg_half_qdq, + self.q_rope_concat_observer, + ) + * sin + ) + ) + key_states = self.k_rope_add_0_output_qdq( + self.k_rope_mul_0_output_qdq(key_states * cos) + + self.k_rope_mul_1_output_qdq( + rotate_half( + key_states, + self.k_proj_input_qdq.fake_quant.activation_post_process, + self.k_rope_neg_half_qdq, + self.k_rope_concat_observer, + ) + * sin + ) + ) + + key_states = self.k_cast_to_int8_qdq(key_states) + value_states = self.v_cast_to_int8_qdq(self.v_cast_to_int16_qdq(value_states)) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + 
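+        # QDQ-aware attention: QK^T is scaled by a quantized scaling constant,
+        # masked positions are filled with (row-min - 20) instead of -inf so the
+        # activations stay inside the quantized range, and the result is softmaxed
+        # before the V matmul.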
attn_weights = self.mul_0_output_qdq( + self.qk_matmul_output_qdq( + torch.matmul(query_states, key_states.transpose(2, 3)) + ) + * self.scaling_qdq( + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * self.scaling + ) + ) + + attn_min = self.reduce_min_output_qdq( + torch.amin(attn_weights, dim=-1, keepdim=True) + ) + attn_vv = self.minus_0_output_qdq( + attn_min + + self.neg_20_qdq( + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * (-20) + ) + ) + attn_weights = self.where_attn_qdq( + torch.where(attention_mask == 0, attn_weights, attn_vv) + ) + + attn_weights = self.softmax_output_qdq( + nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + ) + attn_output = self.attn_value_matmul_output_qdq( + torch.matmul(attn_weights, value_states) + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LlamaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + + # QDQ + if self.layer_idx != 0: + self.input_layernorm_input_qdq = ActivationQDQ(bits=16) + self.add_0_lhs_input_qdq = ActivationQDQ(bits=16) + self.add_0_output_qdq = ActivationQDQ(bits=16) + self.add_1_lhs_input_qdq = ActivationQDQ(bits=16) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + if self.layer_idx != 0: + hidden_states = self.input_layernorm_input_qdq(hidden_states) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.add_0_output_qdq( + residual + self.add_0_lhs_input_qdq(hidden_states) + ) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.add_1_lhs_input_qdq(hidden_states) + return hidden_states + + +@auto_docstring +class LlamaPreTrainedModel(PreTrainedModel): + config: LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + 
_supports_attention_backend = True + _can_record_outputs = { + "hidden_states": LlamaDecoderLayer, + "attentions": LlamaAttention, + } + + +@auto_docstring +class LlamaModel(LlamaPreTrainedModel): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = QEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, quant_bits=16 + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Register sin and cos as buffers + self.register_buffer("mllm_max_sin_embedding", None) + self.register_buffer("mllm_max_cos_embedding", None) + self.sin_embedding_input_qdq = ActivationQDQ(bits=16) + self.cos_embedding_input_qdq = ActivationQDQ(bits=16) + self.norm_input_qdq = ActivationQDQ(bits=16) + + # Initialize weights and apply final processing + self.post_init() + + @torch.no_grad() + def convert_rope_for_deploy(self): + sin_scale = self.sin_embedding_input_qdq.fake_quant.scale + sin_zero_point = self.sin_embedding_input_qdq.fake_quant.zero_point + sin_quant_min = self.sin_embedding_input_qdq.fake_quant.quant_min + sin_quant_max = self.sin_embedding_input_qdq.fake_quant.quant_max + + cos_scale = self.cos_embedding_input_qdq.fake_quant.scale + cos_zero_point = self.cos_embedding_input_qdq.fake_quant.zero_point + cos_quant_min = self.cos_embedding_input_qdq.fake_quant.quant_min + cos_quant_max = self.cos_embedding_input_qdq.fake_quant.quant_max + + sin_int = torch.round( + self.mllm_max_sin_embedding / sin_scale + sin_zero_point + ).clamp(sin_quant_min, sin_quant_max) + self.mllm_max_sin_embedding = sin_int.to(torch.uint16) + + cos_int = torch.round( + self.mllm_max_cos_embedding / cos_scale + cos_zero_point + ).clamp(cos_quant_min, cos_quant_max) + self.mllm_max_cos_embedding = cos_int.to(torch.uint16) + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + if 
self.mllm_max_sin_embedding is None and self.mllm_max_cos_embedding is None: + mllm_qualcomm_max_length = kwargs.get("mllm_qualcomm_max_length", None) + assert mllm_qualcomm_max_length is not None + max_position_ids = torch.arange( + 0, + mllm_qualcomm_max_length, + dtype=position_ids.dtype, + device=position_ids.device, + ).unsqueeze(0) + self.mllm_max_cos_embedding, self.mllm_max_sin_embedding = self.rotary_emb( + hidden_states, max_position_ids + ) + self.mllm_max_cos_embedding = self.mllm_max_cos_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_sin_embedding = self.mllm_max_sin_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_cos_embedding = self.cos_embedding_input_qdq( + self.mllm_max_cos_embedding + ) + self.mllm_max_sin_embedding = self.sin_embedding_input_qdq( + self.mllm_max_sin_embedding + ) + + # create position embeddings to be shared across the decoder layers + position_embeddings = ( + self.mllm_max_cos_embedding[:, position_ids.squeeze(0), :], + self.mllm_max_sin_embedding[:, position_ids.squeeze(0), :], + ) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(self.norm_input_qdq(hidden_states)) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = QLinearLPBQ( + config.hidden_size, config.vocab_size, bias=False, block_size=32 + ) + self.mllm_qualcomm_max_length = None + + self.lm_head_input_qdq = ActivationQDQ(bits=16) + self.lm_head_output_qdq = ActivationQDQ(bits=16) + + # Initialize weights and apply final processing + self.post_init() + + @torch.no_grad() + def copy_lm_head_weight_from_embed_tokens(self): + if self.config.tie_word_embeddings: + self.lm_head.weight.copy_(self.model.embed_tokens.weight) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? 
Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + kwargs.update({"mllm_qualcomm_max_length": self.mllm_qualcomm_max_length}) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head( + self.lm_head_input_qdq(hidden_states[:, slice_indices, :]) + ) + logits = self.lm_head_output_qdq(logits) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class LlamaForSequenceClassification( + GenericForSequenceClassification, LlamaPreTrainedModel +): ... + + +class LlamaForQuestionAnswering(GenericForQuestionAnswering, LlamaPreTrainedModel): + base_model_prefix = ( + "transformer" # For BC, where `transformer` was used instead of `model` + ) + + +class LlamaForTokenClassification( + GenericForTokenClassification, LlamaPreTrainedModel +): ... + + +__all__ = [ + "LlamaForCausalLM", + "LlamaModel", + "LlamaPreTrainedModel", + "LlamaForSequenceClassification", + "LlamaForQuestionAnswering", + "LlamaForTokenClassification", +] diff --git a/pymllm/backends/qualcomm/transformers/llama/runner.py b/pymllm/backends/qualcomm/transformers/llama/runner.py new file mode 100644 index 00000000..8aa4627b --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/llama/runner.py @@ -0,0 +1,345 @@ +import torch +from tqdm import tqdm +from modelscope.msdatasets import MsDataset +from transformers import AutoTokenizer +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.backends.qualcomm.transformers.core.qlinear import ( + QLinearLPBQ, + QLinearW8A16_PerChannelSym, +) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.backends.qualcomm.transformers.llama.modeling_llama import LlamaForCausalLM +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver + + +def recompute_scale_zp(module): + """ + Callback function: Used to forcefully refresh scale and zero_point of all FakeQuantize modules after calibration. + + Problem solved: + When using ConcatObserver, min/max may be updated during forward pass, + but at the end of forward, the scale/zp stored in FakeQuantize's internal buffer are still computed from old min/max. + This function forces a calculate_qparams call to sync the latest parameters to the buffer. + + Usage: + model.apply(recompute_scale_zp) + """ + + # We mainly focus on FakeQuantize modules since they store the scale/zero_point buffers + # Note: model.apply recursively traverses all submodules, so self.fake_quant inside ActivationQDQ will also be visited + if isinstance(module, ActivationQDQ): + observer = module.fake_quant.activation_post_process + + # 2. 
Check if observer is valid and contains statistics + # We only care about MinMaxObserver or MovingAverageMinMaxObserver that have min_val/max_val + if hasattr(observer, "min_val") and hasattr(observer, "max_val"): + # 3. Check if data is initialized + # If min_val is still the initial inf, this layer hasn't processed data, skip to avoid errors + if observer.min_val.numel() == 0 or observer.max_val.numel() == 0: + return + if ( + torch.isinf(observer.min_val).any() + or torch.isinf(observer.max_val).any() + ): + return + + # 4. Recompute Scale and Zero Point + # calculate_qparams reads the current min_val/max_val from observer (may have been modified by ConcatObserver) + try: + scale, zero_point = observer.calculate_qparams() + except Exception as e: + # Some special Observers (e.g., FixedQParams) may not support recomputation or behave differently, safely skip + print(e) + return + + # 5. Force overwrite the computed results to FakeQuantize's Buffer + # Use copy_ to keep reference unchanged, ensuring the new values are used during export + if ( + hasattr(module.fake_quant, "scale") + and module.fake_quant.scale is not None + ): + # Ensure dimension match (handle per-channel vs per-tensor) + if module.fake_quant.scale.shape != scale.shape: + module.fake_quant.scale.resize_(scale.shape) + module.fake_quant.scale.copy_(scale) + # Try to get the registered name of module scale from _parameters or _buffers + for key, value in module.fake_quant.named_parameters(): + if value is module.fake_quant.scale: + print(f"{module._get_name()}.{key}: {module.scale}") + break + + if ( + hasattr(module.fake_quant, "zero_point") + and module.fake_quant.zero_point is not None + ): + if module.fake_quant.zero_point.shape != zero_point.shape: + module.fake_quant.zero_point.resize_(zero_point.shape) + module.fake_quant.zero_point.copy_(zero_point) + + +def validate_concat_observer_fn(module, results: list, name: str = ""): + """ + Callback function: Validate that all input_observers in ConcatObserver have consistent scale and zero_point. 
+ + Usage: + results = [] + for name, m in model.named_modules(): + validate_concat_observer_fn(m, results, name) + """ + if not isinstance(module, ConcatObserver): + return + + input_observers = module.input_observers + if len(input_observers) == 0: + return + + # Collect scale and zero_point from all observers + scales_zps = [] + for i, observer in enumerate(input_observers): + try: + scale, zp = observer.calculate_qparams() + scales_zps.append(f"[{i}] s={scale.item():.8f} zp={zp.item()}") + except Exception: + scales_zps.append(f"[{i}] failed") + + # Print one line: scale and zp of all inputs for each concat observer + print(f"ConcatObserver [{name}]: {' | '.join(scales_zps)}") + + # Original validation logic + if len(input_observers) <= 1: + return + + # Get scale and zero_point from the first observer as reference + first_observer = input_observers[0] + try: + ref_scale, ref_zp = first_observer.calculate_qparams() + except Exception: + return + + # Check if all other observers have the same scale and zero_point + for i, observer in enumerate(input_observers[1:], start=1): + try: + scale, zp = observer.calculate_qparams() + except Exception: + results.append(f"Failed to calculate qparams for observer[{i}]") + continue + + scale_match = torch.allclose(ref_scale, scale, rtol=1e-5, atol=1e-8) + zp_match = torch.equal(ref_zp, zp) + + if not scale_match or not zp_match: + results.append( + f"observer[{i}] mismatch: ref_scale={ref_scale.item():.8f}, " + f"scale={scale.item():.8f}, ref_zp={ref_zp.item()}, zp={zp.item()}" + ) + + +def freeze_llama_rmsnorm_weight(m): + if isinstance(m, QRMSNorm): + m.freeze_weight() + + +def freeze_llama_linear_weight(m): + if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): + m.freeze_weight() + + +def freeze_llama_embed_tokens_weight(m): + if isinstance(m, QEmbedding): + m.freeze_weight() + + +def disable_qdq_observer(m): + if isinstance(m, ActivationQDQ): + m.disable_observer() + + +def enable_qdq_observer(m): + if isinstance(m, ActivationQDQ): + m.enable_observer() + + +def enable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.enable_fakequant() + if isinstance(m, QLinearLPBQ): + m.enable_fakequant() + if isinstance(m, QRMSNorm): + m.enable_fakequant() + if isinstance(m, QEmbedding): + m.enable_fakequant() + + +def disable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.disable_fakequant() + if isinstance(m, QLinearLPBQ): + m.disable_fakequant() + if isinstance(m, QRMSNorm): + m.disable_fakequant() + if isinstance(m, QEmbedding): + m.disable_fakequant() + + +def convert_weight(m): + if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): + m.convert_to_conv2d_deploy_hwio() + if isinstance(m, QRMSNorm): + m.convert_to_deploy() + if isinstance(m, QEmbedding): + m.convert_to_deploy() + + +class LlamaQuantizer: + def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = LlamaForCausalLM.from_pretrained( + model_path, + attn_implementation="eager", + dtype=torch.float32, + ) + self.model.cuda() + self.mllm_qualcomm_max_length = mllm_qualcomm_max_length + self.model.mllm_qualcomm_max_length = mllm_qualcomm_max_length + + if self.model.config.tie_word_embeddings: + self.model.copy_lm_head_weight_from_embed_tokens() + + # PTQ All Weights. 
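+        # The three apply() passes below call freeze_weight() on the RMSNorm, LPBQ /
+        # per-channel linear and embedding modules, fixing their weight-side
+        # quantization parameters from the pretrained float weights up front so that
+        # the later calibration pass only has to collect activation ranges.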
+ self.model.apply(freeze_llama_rmsnorm_weight) + self.model.apply(freeze_llama_linear_weight) + self.model.apply(freeze_llama_embed_tokens_weight) + print("All PTQ weights preparation done.") + + def freeze_activation(self): + self.model.apply(disable_qdq_observer) + + def enable_activation_update(self): + self.model.apply(enable_qdq_observer) + + def enable_fake_quant(self): + self.model.apply(enable_fake_quant) + + def disable_fake_quant(self): + self.model.apply(disable_fake_quant) + + def compile(self): + print("Compile Start.") + self.model = torch.compile( + self.model, mode="reduce-overhead", fullgraph=False, backend="inductor" + ) + print("Compile done.") + + def infer(self, prompt: str): + # Llama models typically don't use chat templates, so we tokenize directly + model_inputs = self.tokenizer([prompt], return_tensors="pt").to( + self.model.device + ) + + # conduct text completion + generated_ids = self.model.generate( + **model_inputs, + max_new_tokens=self.mllm_qualcomm_max_length + - len(model_inputs.input_ids[0]) + - 1, + do_sample=False, + temperature=None, + top_p=None, + top_k=None, + ) + output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() + content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip( + "\n" + ) + + print("content:", content) + + def calibrate(self, num_samples=64, max_seq_length=512): + """ + Perform calibration using Wikipedia dataset (PTQ) + :param num_samples: Number of samples for calibration + :param max_seq_length: Maximum length for each sample (not exceeding mllm_qualcomm_max_length) + """ + print( + f"Starting calibration, samples: {num_samples}, max length: {max_seq_length}" + ) + + # 1. Enable QDQ Observer for activation values + self.enable_activation_update() + self.model.eval() + + # 2. Load Wikipedia dataset (English version example) + # Use streaming=True to download and process on the fly, without downloading the full几十G dataset + dataset = MsDataset.load( + "modelscope/wikitext", + subset_name="wikitext-103-v1", + split="train", + trust_remote_code=True, + ) + + # 3. Execute forward pass (Prefill stage) + samples_processed = 0 + + # Ensure no gradient calculation during inference + with torch.no_grad(): + pbar = tqdm(total=num_samples, desc="Calibrating") + for entry in dataset: + if samples_processed >= num_samples: + break + + if len(entry["text"].strip()) < 1024: + continue + + # Llama models typically don't use chat templates + text = entry["text"] + model_inputs = self.tokenizer( + [text], + return_tensors="pt", + max_length=max_seq_length, + truncation=True, + padding=False, + ).to(self.model.device) + + # Only need Prefill stage: directly call forward + # This will trigger observer update statistics in ActivationQDQ + self.model.generate( + **model_inputs, + max_new_tokens=1, + do_sample=False, + temperature=None, + top_p=None, + top_k=None, + ) + + samples_processed += 1 + pbar.update(1) + + # 4. 
Close Observer, freeze calibrated quantization parameters + self.freeze_activation() + print("\nCalibration completed, activation quantization parameters frozen.") + + def convert(self): + self.model.apply(convert_weight) + self.model.model.convert_rope_for_deploy() + + def recompute_scale_zp(self): + self.model.apply(recompute_scale_zp) + + def validate_concat_observer(self): + results = [] + for name, module in self.model.named_modules(): + validate_concat_observer_fn(module, results, name) + if results: + print("ConcatObserver validation FAILED:") + for msg in results: + print(f" {msg}") + raise ValueError("ConcatObserver validation FAILED") + else: + print( + "ConcatObserver validation PASSED: all observers have matching scale and zp" + ) + print("ConcatObserver validation done.", flush=True) diff --git a/pymllm/backends/qualcomm/transformers/llama/train.py b/pymllm/backends/qualcomm/transformers/llama/train.py new file mode 100644 index 00000000..cd10befb --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/llama/train.py @@ -0,0 +1,56 @@ +import os +import torch +import argparse +from safetensors.torch import save_model +from pymllm.backends.qualcomm.transformers.llama.runner import LlamaQuantizer + + +def main(): + parser = argparse.ArgumentParser(description="Llama Quantizer for Qualcomm backend") + parser.add_argument( + "--model_path", + type=str, + default="", + help="Path to the Llama model directory", + ) + parser.add_argument( + "--max_length", + type=int, + default=2048, + help="Maximum sequence length for quantization", + ) + parser.add_argument( + "--num_samples", type=int, default=128, help="Number of samples for calibration" + ) + parser.add_argument( + "--infer_text", + type=str, + default="为什么伟大不能被计划", + help="Text to run inference on", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Directory to save the quantized model", + ) + + args = parser.parse_args() + + m = LlamaQuantizer(args.model_path, mllm_qualcomm_max_length=args.max_length) + + # FIXME: Should disable or not. 
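+    # Calibration flow: fake quantization is disabled so the observers see float
+    # activations (the FIXME above flags the alternative of calibrating with fake
+    # quant enabled, i.e. against already-quantized upstream activations), then
+    # re-enabled for the test inference. recompute_scale_zp() re-syncs
+    # scale/zero_point from the possibly ConcatObserver-adjusted min/max before the
+    # concat observers are validated, and convert() rewrites the weights and RoPE
+    # tables into their deploy layout before the model is saved with safetensors.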
+ m.disable_fake_quant() + m.calibrate(num_samples=args.num_samples, max_seq_length=args.max_length) + m.enable_fake_quant() + m.recompute_scale_zp() + m.validate_concat_observer() + m.infer(args.infer_text) + m.convert() + + os.makedirs(args.output_dir, exist_ok=True) + model_save_path = os.path.join(args.output_dir, "model.safetensors") + save_model(m.model, model_save_path) + + +if __name__ == "__main__": + main() diff --git a/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py b/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py new file mode 100644 index 00000000..56b19c42 --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py @@ -0,0 +1,804 @@ +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import check_model_inputs +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config + +# Replace linear, rms_norm with: +from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.backends.qualcomm.transformers.core.qlinear import ( + QLinearLPBQ, +) +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = QLinearLPBQ( + self.hidden_size, self.intermediate_size, bias=False, block_size=32 + ) + self.up_proj = QLinearLPBQ( + self.hidden_size, self.intermediate_size, bias=False, block_size=32 + ) + self.down_proj = QLinearLPBQ( + self.intermediate_size, self.hidden_size, bias=False, block_size=32 + ) + + # QDQ + self.up_proj_input_qdq = ActivationQDQ(bits=16) + self.up_proj_output_qdq = ActivationQDQ(bits=16) + self.gate_proj_output_qdq = ActivationQDQ(bits=16) + self.act_output_qdq = ActivationQDQ(bits=16) + self.down_proj_input_qdq = ActivationQDQ(bits=16) + # For sigmoid output: scale = 1 / (q_max - q_min + 1), zp = 0 + # For 16-bit: q_min = 0, q_max = 65535 + sigmoid_scale = 1.0 / (65535 - 0 + 1) # 1 / 65536 + self.sigmoid_output_qdq = FixedActivationQDQ( + scale=sigmoid_scale, zero_point=0, bits=16 + ) + + def forward(self, x): + x = self.up_proj_input_qdq(x) + up_result = self.up_proj_output_qdq(self.up_proj(x)) + gate_result = 
self.gate_proj_output_qdq(self.gate_proj(x)) + + # SiLU + gate_result = self.act_output_qdq( + gate_result * self.sigmoid_output_qdq(F.sigmoid(gate_result)) + ) + + o = self.down_proj_input_qdq(gate_result * up_result) + o = self.down_proj(o) + return o + + +def rotate_half( + x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver +): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1)) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + attention_bias = getattr(config, "attention_bias", True) + self.q_proj = QLinearLPBQ( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=attention_bias, + block_size=32, + ) + self.k_proj = QLinearLPBQ( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=attention_bias, + block_size=32, + ) + self.v_proj = QLinearLPBQ( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=attention_bias, + block_size=32, + ) + self.o_proj = QLinearLPBQ( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=False, + block_size=32, + ) + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + # QDQ + self.q_proj_input_qdq = ActivationQDQ(bits=16) + self.k_proj_input_qdq = ActivationQDQ(bits=16) + + self.q_proj_output_qdq = ActivationQDQ(bits=16) + self.k_proj_output_qdq = ActivationQDQ(bits=16) + + self.v_proj_input_qdq = ActivationQDQ(bits=16) + self.q_rope_mul_0_output_qdq = ActivationQDQ(bits=16) + self.q_rope_mul_1_output_qdq = ActivationQDQ(bits=16) + self.q_rope_add_0_output_qdq = ActivationQDQ(bits=16) + self.k_rope_mul_0_output_qdq = ActivationQDQ(bits=16) + self.k_rope_mul_1_output_qdq = ActivationQDQ(bits=16) + self.k_rope_add_0_output_qdq = ActivationQDQ(bits=16) + + self.q_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.q_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer.add_observer( + self.k_proj_input_qdq.fake_quant.activation_post_process + ) + self.k_rope_concat_observer.add_observer( + self.k_rope_neg_half_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_proj_input_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_rope_neg_half_qdq.fake_quant.activation_post_process + ) + + # In qnn, is uint8 sym. 
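+        # K and V are re-quantized to 8-bit symmetric before entering the KV cache
+        # (V is first observed through a 16-bit QDQ), presumably to keep the
+        # on-device cache at roughly half the footprint of the 16-bit activations.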
+ self.k_cast_to_int8_qdq = ActivationQDQ( + bits=8, qscheme=torch.per_tensor_symmetric + ) + self.v_cast_to_int8_qdq = ActivationQDQ( + bits=8, qscheme=torch.per_tensor_symmetric + ) + + self.v_cast_to_int16_qdq = ActivationQDQ(bits=16) + self.qk_matmul_output_qdq = ActivationQDQ(bits=16) + self.scaling_qdq = ActivationQDQ(bits=16) + self.neg_20_qdq = ActivationQDQ(bits=16) + self.reduce_min_output_qdq = ActivationQDQ(bits=16) + self.mul_0_output_qdq = ActivationQDQ(bits=16) + self.minus_0_output_qdq = ActivationQDQ(bits=16) + self.softmax_output_qdq = ActivationQDQ(bits=16) + self.attn_value_matmul_output_qdq = ActivationQDQ(bits=16) + self.where_attn_qdq = ActivationQDQ(bits=16) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + hidden_states = self.q_proj_input_qdq(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states = self.q_proj_output_qdq(query_states) + + hidden_states_k = self.k_proj_input_qdq(hidden_states) + key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj_output_qdq(key_states) + + hidden_states_v = self.v_proj_input_qdq(hidden_states) + value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + query_states = self.q_rope_add_0_output_qdq( + self.q_rope_mul_0_output_qdq(query_states * cos) + + self.q_rope_mul_1_output_qdq( + rotate_half( + query_states, + self.q_proj_input_qdq.fake_quant.activation_post_process, + self.q_rope_neg_half_qdq, + self.q_rope_concat_observer, + ) + * sin + ) + ) + key_states = self.k_rope_add_0_output_qdq( + self.k_rope_mul_0_output_qdq(key_states * cos) + + self.k_rope_mul_1_output_qdq( + rotate_half( + key_states, + self.k_proj_input_qdq.fake_quant.activation_post_process, + self.k_rope_neg_half_qdq, + self.k_rope_concat_observer, + ) + * sin + ) + ) + + key_states = self.k_cast_to_int8_qdq(key_states) + value_states = self.v_cast_to_int8_qdq(self.v_cast_to_int16_qdq(value_states)) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = self.mul_0_output_qdq( + self.qk_matmul_output_qdq( + torch.matmul(query_states, key_states.transpose(2, 3)) + ) + * self.scaling_qdq( + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * self.scaling + ) + ) + + attn_min = self.reduce_min_output_qdq( + torch.amin(attn_weights, dim=-1, keepdim=True) + ) + attn_vv = self.minus_0_output_qdq( + attn_min + + self.neg_20_qdq( + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * (-20) + ) + ) + attn_weights = self.where_attn_qdq( + 
torch.where(attention_mask == 0, attn_weights, attn_vv) + ) + + attn_weights = self.softmax_output_qdq( + nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + ) + attn_output = self.attn_value_matmul_output_qdq( + torch.matmul(attn_weights, value_states) + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen2RMSNorm(QRMSNorm): + def __init__(self, hidden_size, eps: float = 1e-6, quant_bits=16) -> None: + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__(hidden_size, eps=eps, quant_bits=quant_bits) + + +class Qwen2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.layer_dix = layer_idx + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + self.attention_type = config.layer_types[layer_idx] + + # QDQ + if self.layer_dix != 0: + self.input_layernorm_input_qdq = ActivationQDQ(bits=16) + self.add_0_lhs_input_qdq = ActivationQDQ(bits=16) + self.add_0_output_qdq = ActivationQDQ(bits=16) + self.add_1_lhs_input_qdq = ActivationQDQ(bits=16) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + if self.layer_dix != 0: + hidden_states = self.input_layernorm_input_qdq(hidden_states) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.add_0_output_qdq( + residual + self.add_0_lhs_input_qdq(hidden_states) + ) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.add_1_lhs_input_qdq(hidden_states) + return hidden_states + + +@auto_docstring +class Qwen2PreTrainedModel(PreTrainedModel): + config: Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Qwen2DecoderLayer, + "attentions": Qwen2Attention, + } + + +class Qwen2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def 
__init__(self, config: Qwen2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Qwen2Model(Qwen2PreTrainedModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = QEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, quant_bits=16 + ) + self.layers = nn.ModuleList( + [ + Qwen2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Register sin and cos as buffers + self.register_buffer("mllm_max_sin_embedding", None) + self.register_buffer("mllm_max_cos_embedding", None) + self.sin_embedding_input_qdq = ActivationQDQ(bits=16) + self.cos_embedding_input_qdq = ActivationQDQ(bits=16) + self.norm_input_qdq = ActivationQDQ(bits=16) + + # Initialize weights and apply final processing + self.post_init() + + @torch.no_grad() + def convert_rope_for_deploy(self): + sin_scale = self.sin_embedding_input_qdq.fake_quant.scale + sin_zero_point = self.sin_embedding_input_qdq.fake_quant.zero_point + sin_quant_min = self.sin_embedding_input_qdq.fake_quant.quant_min + sin_quant_max = self.sin_embedding_input_qdq.fake_quant.quant_max + + cos_scale = self.cos_embedding_input_qdq.fake_quant.scale + cos_zero_point = self.cos_embedding_input_qdq.fake_quant.zero_point + cos_quant_min = self.cos_embedding_input_qdq.fake_quant.quant_min + cos_quant_max = self.cos_embedding_input_qdq.fake_quant.quant_max + + sin_int = torch.round( + self.mllm_max_sin_embedding / sin_scale + sin_zero_point + ).clamp(sin_quant_min, sin_quant_max) + self.mllm_max_sin_embedding = sin_int.to(torch.uint16) + + cos_int = torch.round( + self.mllm_max_cos_embedding / cos_scale + cos_zero_point + ).clamp(cos_quant_min, cos_quant_max) + self.mllm_max_cos_embedding = 
cos_int.to(torch.uint16) + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = ( + create_sliding_window_causal_mask(**mask_kwargs) + ) + + hidden_states = inputs_embeds + + if self.mllm_max_sin_embedding is None and self.mllm_max_cos_embedding is None: + mllm_qualcomm_max_length = kwargs.get("mllm_qualcomm_max_length", None) + assert mllm_qualcomm_max_length is not None + max_position_ids = torch.arange( + 0, + mllm_qualcomm_max_length, + dtype=position_ids.dtype, + device=position_ids.device, + ).unsqueeze(0) + self.mllm_max_cos_embedding, self.mllm_max_sin_embedding = self.rotary_emb( + hidden_states, max_position_ids + ) + self.mllm_max_cos_embedding = self.mllm_max_cos_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_sin_embedding = self.mllm_max_sin_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_cos_embedding = self.cos_embedding_input_qdq( + self.mllm_max_cos_embedding + ) + self.mllm_max_sin_embedding = self.sin_embedding_input_qdq( + self.mllm_max_sin_embedding + ) + + # create position embeddings to be shared across the decoder layers + position_embeddings = ( + self.mllm_max_cos_embedding[:, position_ids.squeeze(0), :], + self.mllm_max_sin_embedding[:, position_ids.squeeze(0), :], + ) + + # Generate causal mask based on position_ids length + # For prefill, we need a lower triangular mask + _, seq_len = input_ids.shape + if seq_len != 1: + causal_mask = 1 - torch.tril( + torch.ones(seq_len, seq_len, dtype=torch.int8, device=input_ids.device) + ) + # [1, 1, seq_len, seq_len] + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + else: + # [1, 1, seq_len, seq_len] + causal_mask = torch.zeros( + (1, 1, 1, seq_len), dtype=torch.int8, device=input_ids.device + ) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, 
+ attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(self.norm_input_qdq(hidden_states)) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +@auto_docstring +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.config = config + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = QLinearLPBQ( + config.hidden_size, config.vocab_size, bias=False, block_size=32 + ) + self.mllm_qualcomm_max_length = None + + self.lm_head_input_qdq = ActivationQDQ(bits=16) + self.lm_head_output_qdq = ActivationQDQ(bits=16) + + # Initialize weights and apply final processing + self.post_init() + + @torch.no_grad() + def copy_lm_head_weight_from_embed_tokens(self): + if self.config.tie_word_embeddings: + self.lm_head.weight.copy_(self.model.embed_tokens.weight) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + kwargs.update({"mllm_qualcomm_max_length": self.mllm_qualcomm_max_length}) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head( + self.lm_head_input_qdq(hidden_states[:, slice_indices, :]) + ) + logits = self.lm_head_output_qdq(logits) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Qwen2ForSequenceClassification( + GenericForSequenceClassification, Qwen2PreTrainedModel +): + pass + + +class Qwen2ForTokenClassification(GenericForTokenClassification, Qwen2PreTrainedModel): + pass + + +class Qwen2ForQuestionAnswering(GenericForQuestionAnswering, Qwen2PreTrainedModel): + base_model_prefix = ( + "transformer" # For BC, where `transformer` was used instead of `model` + ) + + +__all__ = [ + "Qwen2PreTrainedModel", + "Qwen2Model", + "Qwen2ForCausalLM", + "Qwen2RMSNorm", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2ForQuestionAnswering", +] diff --git a/pymllm/backends/qualcomm/transformers/qwen2/runner.py b/pymllm/backends/qualcomm/transformers/qwen2/runner.py new file mode 100644 index 00000000..d2f5be05 --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/qwen2/runner.py @@ -0,0 +1,352 @@ +import torch +from tqdm import tqdm +from modelscope.msdatasets import MsDataset +from transformers import AutoTokenizer +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.backends.qualcomm.transformers.core.qlinear import ( + QLinearLPBQ, + QLinearW8A16_PerChannelSym, +) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.backends.qualcomm.transformers.qwen2.modeling_qwen2 import Qwen2ForCausalLM +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver + + +def recompute_scale_zp(module): + """ + Callback function: Used to forcefully refresh scale and zero_point of all FakeQuantize modules after calibration. + + Problem solved: + When using ConcatObserver, min/max may be updated during forward pass, + but at the end of forward, the scale/zp stored in FakeQuantize's internal buffer are still computed from old min/max. + This function forces a calculate_qparams call to sync the latest parameters to the buffer. + + Usage: + model.apply(recompute_scale_zp) + """ + + # We mainly focus on FakeQuantize modules since they store the scale/zero_point buffers + # Note: model.apply recursively traverses all submodules, so self.fake_quant inside ActivationQDQ will also be visited + if isinstance(module, ActivationQDQ): + observer = module.fake_quant.activation_post_process + + # 2. 
Check if observer is valid and contains statistics + # We only care about MinMaxObserver or MovingAverageMinMaxObserver that have min_val/max_val + if hasattr(observer, "min_val") and hasattr(observer, "max_val"): + # 3. Check if data is initialized + # If min_val is still the initial inf, this layer hasn't processed data, skip to avoid errors + if observer.min_val.numel() == 0 or observer.max_val.numel() == 0: + return + if ( + torch.isinf(observer.min_val).any() + or torch.isinf(observer.max_val).any() + ): + return + + # 4. Recompute Scale and Zero Point + # calculate_qparams reads the current min_val/max_val from observer (may have been modified by ConcatObserver) + try: + scale, zero_point = observer.calculate_qparams() + except Exception as e: + # Some special Observers (e.g., FixedQParams) may not support recomputation or behave differently, safely skip + print(e) + return + + # 5. Force overwrite the computed results to FakeQuantize's Buffer + # Use copy_ to keep reference unchanged, ensuring the new values are used during export + if ( + hasattr(module.fake_quant, "scale") + and module.fake_quant.scale is not None + ): + # Ensure dimension match (handle per-channel vs per-tensor) + if module.fake_quant.scale.shape != scale.shape: + module.fake_quant.scale.resize_(scale.shape) + module.fake_quant.scale.copy_(scale) + # Try to get the registered name of module scale from _parameters or _buffers + for key, value in module.fake_quant.named_parameters(): + if value is module.fake_quant.scale: + print(f"{module._get_name()}.{key}: {module.scale}") + break + + if ( + hasattr(module.fake_quant, "zero_point") + and module.fake_quant.zero_point is not None + ): + if module.fake_quant.zero_point.shape != zero_point.shape: + module.fake_quant.zero_point.resize_(zero_point.shape) + module.fake_quant.zero_point.copy_(zero_point) + + +def validate_concat_observer_fn(module, results: list, name: str = ""): + """ + Callback function: Validate that all input_observers in ConcatObserver have consistent scale and zero_point. 
+ + Usage: + results = [] + for name, m in model.named_modules(): + validate_concat_observer_fn(m, results, name) + """ + if not isinstance(module, ConcatObserver): + return + + input_observers = module.input_observers + if len(input_observers) == 0: + return + + # Collect scale and zero_point from all observers + scales_zps = [] + for i, observer in enumerate(input_observers): + try: + scale, zp = observer.calculate_qparams() + scales_zps.append(f"[{i}] s={scale.item():.8f} zp={zp.item()}") + except Exception: + scales_zps.append(f"[{i}] failed") + + # Print one line: scale and zp of all inputs for each concat observer + print(f"ConcatObserver [{name}]: {' | '.join(scales_zps)}") + + # Original validation logic + if len(input_observers) <= 1: + return + + # Get scale and zero_point from the first observer as reference + first_observer = input_observers[0] + try: + ref_scale, ref_zp = first_observer.calculate_qparams() + except Exception: + return + + # Check if all other observers have the same scale and zero_point + for i, observer in enumerate(input_observers[1:], start=1): + try: + scale, zp = observer.calculate_qparams() + except Exception: + results.append(f"Failed to calculate qparams for observer[{i}]") + continue + + scale_match = torch.allclose(ref_scale, scale, rtol=1e-5, atol=1e-8) + zp_match = torch.equal(ref_zp, zp) + + if not scale_match or not zp_match: + results.append( + f"observer[{i}] mismatch: ref_scale={ref_scale.item():.8f}, " + f"scale={scale.item():.8f}, ref_zp={ref_zp.item()}, zp={zp.item()}" + ) + + +def freeze_qwen2_rmsnorm_weight(m): + if isinstance(m, QRMSNorm): + m.freeze_weight() + + +def freeze_qwen2_linear_weight(m): + if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): + m.freeze_weight() + + +def freeze_qwen2_embed_tokens_weight(m): + if isinstance(m, QEmbedding): + m.freeze_weight() + + +def disable_qdq_observer(m): + if isinstance(m, ActivationQDQ): + m.disable_observer() + + +def enable_qdq_observer(m): + if isinstance(m, ActivationQDQ): + m.enable_observer() + + +def enable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.enable_fakequant() + if isinstance(m, QLinearLPBQ): + m.enable_fakequant() + if isinstance(m, QRMSNorm): + m.enable_fakequant() + if isinstance(m, QEmbedding): + m.enable_fakequant() + + +def disable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.disable_fakequant() + if isinstance(m, QLinearLPBQ): + m.disable_fakequant() + if isinstance(m, QRMSNorm): + m.disable_fakequant() + if isinstance(m, QEmbedding): + m.disable_fakequant() + + +def convert_weight(m): + if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): + m.convert_to_conv2d_deploy_hwio() + if isinstance(m, QRMSNorm): + m.convert_to_deploy() + if isinstance(m, QEmbedding): + m.convert_to_deploy() + + +class Qwen2Quantizer: + def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = Qwen2ForCausalLM.from_pretrained( + model_path, + attn_implementation="eager", + dtype=torch.float32, + ) + self.model.cuda() + self.mllm_qualcomm_max_length = mllm_qualcomm_max_length + self.model.mllm_qualcomm_max_length = mllm_qualcomm_max_length + + if self.model.config.tie_word_embeddings: + self.model.copy_lm_head_weight_from_embed_tokens() + + # PTQ All Weights. 
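+        # Mirrors LlamaQuantizer: weight quantization is frozen here at construction
+        # time, and activation ranges are collected later in calibrate(). The main
+        # difference in this runner is that infer() and calibrate() below wrap the
+        # text with the Qwen chat template instead of tokenizing raw prompts.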
+ self.model.apply(freeze_qwen2_rmsnorm_weight) + self.model.apply(freeze_qwen2_linear_weight) + self.model.apply(freeze_qwen2_embed_tokens_weight) + print("All PTQ weights preparation done.") + + def freeze_activation(self): + self.model.apply(disable_qdq_observer) + + def enable_activation_update(self): + self.model.apply(enable_qdq_observer) + + def enable_fake_quant(self): + self.model.apply(enable_fake_quant) + + def disable_fake_quant(self): + self.model.apply(disable_fake_quant) + + def compile(self): + print("Compile Start.") + self.model = torch.compile( + self.model, mode="reduce-overhead", fullgraph=False, backend="inductor" + ) + print("Compile done.") + + def infer(self, prompt: str): + messages = [{"role": "user", "content": prompt}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) + + # conduct text completion + generated_ids = self.model.generate( + **model_inputs, + max_new_tokens=self.mllm_qualcomm_max_length + - len(model_inputs.input_ids[0]) + - 1, + do_sample=False, + temperature=None, + top_p=None, + top_k=None, + ) + output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() + content = self.tokenizer.decode( + output_ids, skip_special_tokens=True + ).strip("\n") + + print("content:", content) + + def calibrate(self, num_samples=64, max_seq_length=512): + """ + Perform calibration using Wikipedia dataset (PTQ) + :param num_samples: Number of samples for calibration + :param max_seq_length: Maximum length for each sample (not exceeding mllm_qualcomm_max_length) + """ + print( + f"Starting calibration, samples: {num_samples}, max length: {max_seq_length}" + ) + + # 1. Enable QDQ Observer for activation values + self.enable_activation_update() + self.model.eval() + + # 2. Load Wikipedia dataset (English version example) + # Use streaming=True to download and process on the fly, without downloading the full几十G dataset + dataset = MsDataset.load( + "modelscope/wikitext", + subset_name="wikitext-103-v1", + split="train", + trust_remote_code=True, + ) + + # 3. Execute forward pass (Prefill stage) + samples_processed = 0 + + # Ensure no gradient calculation during inference + with torch.no_grad(): + pbar = tqdm(total=num_samples, desc="Calibrating") + for entry in dataset: + if samples_processed >= num_samples: + break + + if len(entry["text"].strip()) < 1024: + continue + + messages = [{"role": "user", "content": entry["text"]}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + model_inputs = self.tokenizer( + [text], + return_tensors="pt", + max_length=max_seq_length, + truncation=True, + padding=False, + ).to(self.model.device) + + # Only need Prefill stage: directly call forward + # This will trigger observer update statistics in ActivationQDQ + self.model.generate( + **model_inputs, + max_new_tokens=1, + do_sample=False, + temperature=None, + top_p=None, + top_k=None, + ) + + samples_processed += 1 + pbar.update(1) + + # 4. 
Close Observer, freeze calibrated quantization parameters + self.freeze_activation() + print("\nCalibration completed, activation quantization parameters frozen.") + + def convert(self): + self.model.apply(convert_weight) + self.model.model.convert_rope_for_deploy() + + def recompute_scale_zp(self): + self.model.apply(recompute_scale_zp) + + def validate_concat_observer(self): + results = [] + for name, module in self.model.named_modules(): + validate_concat_observer_fn(module, results, name) + if results: + print("ConcatObserver validation FAILED:") + for msg in results: + print(f" {msg}") + raise ValueError("ConcatObserver validation FAILED") + else: + print( + "ConcatObserver validation PASSED: all observers have matching scale and zp" + ) + print("ConcatObserver validation done.", flush=True) diff --git a/pymllm/backends/qualcomm/transformers/qwen2/train.py b/pymllm/backends/qualcomm/transformers/qwen2/train.py new file mode 100644 index 00000000..fec5fdfc --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/qwen2/train.py @@ -0,0 +1,56 @@ +import os +import torch +import argparse +from safetensors.torch import save_model +from pymllm.backends.qualcomm.transformers.qwen2.runner import Qwen2Quantizer + + +def main(): + parser = argparse.ArgumentParser(description="Qwen2 Quantizer for Qualcomm backend") + parser.add_argument( + "--model_path", + type=str, + default="Qwen/Qwen2-1.5B", + help="Path to the Qwen2 model directory", + ) + parser.add_argument( + "--max_length", + type=int, + default=2048, + help="Maximum sequence length for quantization", + ) + parser.add_argument( + "--num_samples", type=int, default=128, help="Number of samples for calibration" + ) + parser.add_argument( + "--infer_text", + type=str, + default="为什么伟大不能被计划", + help="Text to run inference on", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Directory to save the quantized model", + ) + + args = parser.parse_args() + + m = Qwen2Quantizer(args.model_path, mllm_qualcomm_max_length=args.max_length) + + # FIXME: Should disable or not. + m.disable_fake_quant() + m.calibrate(num_samples=args.num_samples, max_seq_length=args.max_length) + m.enable_fake_quant() + m.recompute_scale_zp() + m.validate_concat_observer() + m.infer(args.infer_text) + m.convert() + + os.makedirs(args.output_dir, exist_ok=True) + model_save_path = os.path.join(args.output_dir, "model.safetensors") + save_model(m.model, model_save_path) + + +if __name__ == "__main__": + main()
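For reference, here is a minimal, self-contained sketch of the per-tensor affine quantize/dequantize round trip that the 16-bit ActivationQDQ nodes above emulate, and that convert_rope_for_deploy applies when baking the RoPE tables to uint16. It is illustrative only and not part of the patch; the names are made up, and the range handling (widening the observed range to include zero) is an assumption rather than the exact observer behaviour.

import torch

QUANT_MIN, QUANT_MAX = 0, 2**16 - 1  # unsigned 16-bit, matching the ConcatObserver bounds


def affine_qparams(x: torch.Tensor):
    # Per-tensor affine parameters; the range is widened to include 0 so that
    # zero is exactly representable (assumed, typical min/max observer behaviour).
    x_min = min(x.min().item(), 0.0)
    x_max = max(x.max().item(), 0.0)
    scale = (x_max - x_min) / (QUANT_MAX - QUANT_MIN)
    zero_point = int(round(QUANT_MIN - x_min / scale))
    return scale, zero_point


def fake_quant(x: torch.Tensor, scale: float, zero_point: int):
    # Quantize, clamp, dequantize: the numeric effect of an ActivationQDQ(bits=16) node.
    q = torch.clamp(torch.round(x / scale + zero_point), QUANT_MIN, QUANT_MAX)
    return (q - zero_point) * scale


x = torch.randn(4, 128)
scale, zp = affine_qparams(x)
x_dq = fake_quant(x, scale, zp)
print(f"scale={scale:.3e} zp={zp} max error={(x - x_dq).abs().max().item():.3e}")

convert_rope_for_deploy performs only the first half of this round trip (round, clamp, cast to uint16) and stores the integer tables, presumably leaving dequantization to the QNN runtime.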