diff --git a/examples/qwen3_qnn_aot/compile.cpp b/examples/qwen3_qnn_aot/compile.cpp
index 6404af3c..637d9421 100644
--- a/examples/qwen3_qnn_aot/compile.cpp
+++ b/examples/qwen3_qnn_aot/compile.cpp
@@ -1,13 +1,13 @@
 // Copyright (c) MLLM Team.
 // Licensed under the MIT License.
 
-#include
 #include
 #include
 #include
 #include
 #include
 
+#include "compile_common.hpp"
 #include "modeling_qwen_qnn_aot.hpp"
 
 using mllm::Argparse;
@@ -20,11 +20,11 @@ MLLM_MAIN({
   auto& qnn_env_path = Argparse::add<std::string>("-qnn_env|--qnn_env_path")
                            .def("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/")
                            .help("QNN AOT Environment path.");
+  auto& output_context_path = Argparse::add<std::string>("-o|--output_context_name").help("Output QNN context path.");
 
   Argparse::parse(argc, argv);
 
-  int N = 32;
-  int CL = 1024;
+  constexpr int kContextLength = 1024;
 
   if (help.isSet()) {
     Argparse::printHelp();
@@ -36,128 +36,33 @@ MLLM_MAIN({
     Argparse::printHelp();
     return -1;
   }
+  if (!output_context_path.isSet()) {
+    Argparse::printHelp();
+    MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No output context path provided");
+    return -1;
+  }
 
   auto model_cfg = mllm::models::qwen3::Qwen3Config(model_cfg_path.get());
   auto model = mllm::models::qwen3::Qwen3ForCausalLM(model_cfg);
   auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2);
 
-  // Add params for causal mask
-  {
-    params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32));
-    params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32));
-    params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32));
-    params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32));
-  }
+  qwen3_qnn_aot::addCausalMaskParams(params);
 
   model.load(params);
 
   // Create Qnn AOT Model
   auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv(
       qnn_env_path.get(), mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get()));
 
-  // Model length 32.
-
-  {
-    // Sequence: [B, N]
-    // past_key_i: [B, H, D, CL-N] for each layer i
-    // past_value_i: [B, H, CL-N, D] for each layer i
-    // causal_mask: [B, 1, N, CL]
-    auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32);
-    auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16);
-
-    // NOTE: force set causal mask to UInt16Asy
-    // NOTE: Attach scale and zero point to causal mask
-    {
-      causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy);
-      causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true);
-      causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true);
-    }
-
-    // Create KV cache inputs for all layers
-    std::unordered_map<std::string, mllm::Tensor> trace_inputs;
-    trace_inputs["sequence"] = sequence;
-    trace_inputs["causal_mask"] = causal_mask;
-
-    for (int i = 0; i < model_cfg.num_hidden_layers; ++i) {
-      auto past_key_name = "past_key_" + std::to_string(i);
-      auto past_value_name = "past_value_" + std::to_string(i);
-
-      // clang-format off
-      trace_inputs[past_key_name] = mllm::Tensor::empty({
-        1,
-        model_cfg.num_key_value_heads,
-        model_cfg.head_dim,
-        CL - N,
-      }, mllm::kUInt8PerTensorSym);
-      trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym);
-
-      trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true);
-      trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true);
-
-      trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true);
-      trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true);
-      // clang-format on
-    }
-
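+  // Trace one graph per sequence length (32 for prefill, 1 for decode), lower it
+  // through the QNN AOT pipeline, and dump the lowered IR to `mir_path`.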
+ std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); - - trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); - trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); - // clang-format on - } - + auto trace_and_dump = [&](int seq_len, const std::string& mir_path) { + auto trace_inputs = qwen3_qnn_aot::makeTraceInputs(seq_len, kContextLength, model_cfg, params); auto ir = model.trace(trace_inputs, {}); - mllm::ir::PassManager pm(ir["model"]); pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); pm.run(); + mllm::redirect(mir_path, [&]() { mllm::print(ir["model"]); }); + }; - mllm::redirect("qwen3_qnn_aot_32.mir", [&]() { mllm::print(ir["model"]); }); - } - - // Model length 1. - { - N = 1; - - // Sequence: [B, N] - // past_key_i: [B, H, D, CL-N] for each layer i - // past_value_i: [B, H, CL-N, D] for each layer i - // causal_mask: [B, 1, N, CL] - auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); - auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); - - // NOTE: force set causal mask to UInt16Asy - // NOTE: Attach scale and zero point to causal mask - { - causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); - causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); - causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); - } - - // Create KV cache inputs for all layers - std::unordered_map trace_inputs; - trace_inputs["sequence"] = sequence; - trace_inputs["causal_mask"] = causal_mask; - for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { - auto past_key_name = "past_key_" + std::to_string(i); - auto past_value_name = "past_value_" + std::to_string(i); - - // clang-format off - trace_inputs[past_key_name] = mllm::Tensor::empty({ - 1, - model_cfg.num_key_value_heads, - model_cfg.head_dim, - CL - N, - }, mllm::kUInt8PerTensorSym); - trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); - - trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); - trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); - - trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); - trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); - // clang-format on - } - - auto ir = model.trace(trace_inputs, {}); - - mllm::ir::PassManager pm(ir["model"]); - pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); - pm.run(); - - mllm::redirect("qwen3_qnn_aot_1.mir", [&]() { mllm::print(ir["model"]); }); - } + trace_and_dump(32, "qwen3_qnn_aot_32.mir"); + trace_and_dump(1, "qwen3_qnn_aot_1.mir"); - qnn_aot_env.saveContext("context.0", "qwen3-1.7B-lpbq.bin"); + qnn_aot_env.saveContext("context.0", output_context_path.get()); }); diff --git a/examples/qwen3_qnn_aot/compile_common.hpp b/examples/qwen3_qnn_aot/compile_common.hpp new file mode 100644 index 00000000..817339fe --- /dev/null +++ b/examples/qwen3_qnn_aot/compile_common.hpp @@ -0,0 +1,76 @@ +#pragma once + +#include +#include + +#include +#include + +namespace qwen3_qnn_aot { + +template +inline void addCausalMaskParams(const ParamsT& params) { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); +} + +template +inline std::unordered_map makeTraceInputs(int seq_len, + int context_len, + const mllm::models::qwen3::Qwen3Config& model_cfg, + const ParamsT& params) { + auto sequence = mllm::Tensor::zeros({1, seq_len}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, seq_len, context_len}, mllm::kUInt16); + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + context_len - seq_len, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + context_len - seq_len, + model_cfg.head_dim, + }, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", + params->pull("model.layers." + std::to_string(i) + + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale") + .impl(), + true); + trace_inputs[past_key_name].attach("zero_point", + params->pull("model.layers." + std::to_string(i) + + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point") + .impl(), + true); + trace_inputs[past_value_name].attach("scale", + params->pull("model.layers." + std::to_string(i) + + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale") + .impl(), + true); + trace_inputs[past_value_name].attach("zero_point", + params->pull("model.layers." 
+  auto sequence = mllm::Tensor::zeros({1, seq_len}, mllm::kInt32);
+  auto causal_mask = mllm::Tensor::zeros({1, 1, seq_len, context_len}, mllm::kUInt16);
+  causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy);
+  causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true);
+  causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true);
+
+  std::unordered_map<std::string, mllm::Tensor> trace_inputs;
+  trace_inputs["sequence"] = sequence;
+  trace_inputs["causal_mask"] = causal_mask;
+
+  for (int i = 0; i < model_cfg.num_hidden_layers; ++i) {
+    auto past_key_name = "past_key_" + std::to_string(i);
+    auto past_value_name = "past_value_" + std::to_string(i);
+
+    trace_inputs[past_key_name] = mllm::Tensor::empty({
+        1,
+        model_cfg.num_key_value_heads,
+        model_cfg.head_dim,
+        context_len - seq_len,
+    }, mllm::kUInt8PerTensorSym);
+    trace_inputs[past_value_name] = mllm::Tensor::empty({
+        1,
+        model_cfg.num_key_value_heads,
+        context_len - seq_len,
+        model_cfg.head_dim,
+    }, mllm::kUInt8PerTensorSym);
+
+    trace_inputs[past_key_name].attach("scale",
+                                       params->pull("model.layers." + std::to_string(i) +
+                                                    ".self_attn.k_cast_to_int8_qdq.fake_quant.scale")
+                                           .impl(),
+                                       true);
+    trace_inputs[past_key_name].attach("zero_point",
+                                       params->pull("model.layers." + std::to_string(i) +
+                                                    ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point")
+                                           .impl(),
+                                       true);
+    trace_inputs[past_value_name].attach("scale",
+                                         params->pull("model.layers." + std::to_string(i) +
+                                                      ".self_attn.v_cast_to_int8_qdq.fake_quant.scale")
+                                             .impl(),
+                                         true);
+    trace_inputs[past_value_name].attach("zero_point",
+                                         params->pull("model.layers." + std::to_string(i) +
+                                                      ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point")
+                                             .impl(),
+                                         true);
+  }
+
+  return trace_inputs;
+}
+
+} // namespace qwen3_qnn_aot
diff --git a/examples/qwen3_qnn_aot/compile_sha.cpp b/examples/qwen3_qnn_aot/compile_sha.cpp
index 9f2629f6..5478b350 100644
--- a/examples/qwen3_qnn_aot/compile_sha.cpp
+++ b/examples/qwen3_qnn_aot/compile_sha.cpp
@@ -9,13 +9,13 @@
 // Usage:
-// ./compile_sha -m /path/to/model.mllm -c /path/to/config.json -aot_cfg /path/to/qnn_aot_cfg.json
+// ./compile_sha -m /path/to/model.mllm -c /path/to/config.json -aot_cfg /path/to/qnn_aot_cfg.json -o /path/to/context.bin
 
-#include
 #include
 #include
 #include
 #include
 #include
 
+#include "compile_common.hpp"
 #include "modeling_qwen_qnn_aot_sha.hpp"
 
 using mllm::Argparse;
@@ -28,11 +28,11 @@ MLLM_MAIN({
   auto& qnn_env_path = Argparse::add<std::string>("-qnn_env|--qnn_env_path")
                            .def("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/")
                            .help("QNN AOT Environment path.");
+  auto& output_context_path = Argparse::add<std::string>("-o|--output_context_name").help("Output QNN context path.");
 
   Argparse::parse(argc, argv);
 
-  int N = 32;
-  int CL = 1024;
+  constexpr int kContextLength = 1024;
 
   if (help.isSet()) {
     Argparse::printHelp();
@@ -44,6 +44,11 @@ MLLM_MAIN({
     Argparse::printHelp();
     return -1;
   }
+  if (!output_context_path.isSet()) {
+    Argparse::printHelp();
+    MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No output context path provided");
+    return -1;
+  }
 
   auto model_cfg = mllm::models::qwen3::Qwen3Config(model_cfg_path.get());
 
@@ -66,130 +71,32 @@ MLLM_MAIN({
   // Create SHA model
   auto model = mllm::models::qwen3::sha::Qwen3ForCausalLM_SHA(model_cfg);
 
-  // Add params for causal mask
-  {
-    params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32));
-    params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32));
-    params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32));
-    params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32));
-  }
+  qwen3_qnn_aot::addCausalMaskParams(params);
 
   model.load(params);
 
   // Create Qnn AOT Model
   auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv(
       qnn_env_path.get(), mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get()));
 
-  // Model length 32.
-
-  {
-    // Sequence: [B, N]
-    // past_key_i: [B, H, D, CL-N] for each layer i
-    // past_value_i: [B, H, CL-N, D] for each layer i
-    // causal_mask: [B, 1, N, CL]
-    auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32);
-    auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16);
-
-    // NOTE: force set causal mask to UInt16Asy
-    // NOTE: Attach scale and zero point to causal mask
-    {
-      causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy);
-      causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true);
-      causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true);
-    }
-
-    // Create KV cache inputs for all layers
-    std::unordered_map<std::string, mllm::Tensor> trace_inputs;
-    trace_inputs["sequence"] = sequence;
-    trace_inputs["causal_mask"] = causal_mask;
-
-    for (int i = 0; i < model_cfg.num_hidden_layers; ++i) {
-      auto past_key_name = "past_key_" + std::to_string(i);
-      auto past_value_name = "past_value_" + std::to_string(i);
-
-      // clang-format off
-      trace_inputs[past_key_name] = mllm::Tensor::empty({
-        1,
-        model_cfg.num_key_value_heads,
-        model_cfg.head_dim,
-        CL - N,
-      }, mllm::kUInt8PerTensorSym);
-      trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym);
-
-      trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true);
-      trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true);
-
-      trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true);
-      trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true);
-      // clang-format on
-    }
-
-    mllm::print("Tracing SHA model (seq=32)...");
+  auto trace_and_dump = [&](int seq_len, const std::string& mir_path) {
+    auto trace_inputs = qwen3_qnn_aot::makeTraceInputs(seq_len, kContextLength, model_cfg, params);
+    mllm::print("Tracing SHA model (seq=" + std::to_string(seq_len) + ")...");
     auto ir = model.trace(trace_inputs, {});
     mllm::print("SHA model traced successfully.");
-
     mllm::ir::PassManager pm(ir["model"]);
     pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params));
     pm.run();
+    mllm::redirect(mir_path, [&]() { mllm::print(ir["model"]); });
+  };
 
-    mllm::redirect("qwen3_qnn_aot_sha_32.mir", [&]() { mllm::print(ir["model"]); });
-  }
-
-  // Model length 1.
-  {
-    N = 1;
-
-    // Sequence: [B, N]
-    // past_key_i: [B, H, D, CL-N] for each layer i
-    // past_value_i: [B, H, CL-N, D] for each layer i
-    // causal_mask: [B, 1, N, CL]
-    auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32);
-    auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16);
-
-    // NOTE: force set causal mask to UInt16Asy
-    // NOTE: Attach scale and zero point to causal mask
-    {
-      causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy);
-      causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true);
-      causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true);
-    }
-
-    // Create KV cache inputs for all layers
-    std::unordered_map<std::string, mllm::Tensor> trace_inputs;
-    trace_inputs["sequence"] = sequence;
-    trace_inputs["causal_mask"] = causal_mask;
-    for (int i = 0; i < model_cfg.num_hidden_layers; ++i) {
-      auto past_key_name = "past_key_" + std::to_string(i);
-      auto past_value_name = "past_value_" + std::to_string(i);
-
-      // clang-format off
-      trace_inputs[past_key_name] = mllm::Tensor::empty({
-        1,
-        model_cfg.num_key_value_heads,
-        model_cfg.head_dim,
-        CL - N,
-      }, mllm::kUInt8PerTensorSym);
-      trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym);
-      trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true);
-      trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true);
-      // clang-format on
-    }
-
-    mllm::print("Tracing SHA model (seq=1)...");
-    auto ir = model.trace(trace_inputs, {});
-    mllm::print("SHA model traced successfully.");
-
-    mllm::ir::PassManager pm(ir["model"]);
-    pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params));
-    pm.run();
-
-    mllm::redirect("qwen3_qnn_aot_sha_1.mir", [&]() { mllm::print(ir["model"]); });
-  }
+  trace_and_dump(32, "qwen3_qnn_aot_sha_32.mir");
+  trace_and_dump(1, "qwen3_qnn_aot_sha_1.mir");
 
-  qnn_aot_env.saveContext("context.0", "qwen3-1.7B-lpbq-sha.bin");
+  qnn_aot_env.saveContext("context.0", output_context_path.get());
 
   mllm::print("SHA compilation completed successfully!");
   mllm::print("Output files:");
   mllm::print(" - qwen3_qnn_aot_sha_32.mir (IR dump for seq=32)");
   mllm::print(" - qwen3_qnn_aot_sha_1.mir (IR dump for seq=1)");
-  mllm::print(" - qwen3-1.7B-lpbq-sha.bin (QNN context)");
+  mllm::print(" - " + output_context_path.get() + " (QNN context)");
 });
diff --git a/examples/qwen3_qnn_aot/config_1.7B.json b/examples/qwen3_qnn_aot/config_1.7B.json
index 4b91c87c..1a3d6f7e 100644
--- a/examples/qwen3_qnn_aot/config_1.7B.json
+++ b/examples/qwen3_qnn_aot/config_1.7B.json
@@ -28,5 +28,6 @@
   "use_sliding_window": false,
   "vocab_size": 151936,
   "max_cache_length": 2048,
+  "linear_block_size": 16,
   "linear_impl_type": "QNN_LPBQ_w4a16o16_G16"
 }
diff --git a/examples/qwen3_qnn_aot/config_4B.json b/examples/qwen3_qnn_aot/config_4B.json
index dc9fdc08..17883e09 100644
--- a/examples/qwen3_qnn_aot/config_4B.json
+++ b/examples/qwen3_qnn_aot/config_4B.json
@@ -17,5 +17,6 @@
   "max_cache_length": 2048,
   "rope_theta": 5000000.0,
   "tie_word_embeddings": true,
+  "linear_block_size": 32,
   "linear_impl_type": "QNN_LPBQ_w4a16o16_G32"
 }
diff --git a/pymllm/mobile/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/mobile/backends/qualcomm/transformers/qwen3/modeling_qwen3.py
index 6a8788ba..c48be055 100644
--- a/pymllm/mobile/backends/qualcomm/transformers/qwen3/modeling_qwen3.py
+++ b/pymllm/mobile/backends/qualcomm/transformers/qwen3/modeling_qwen3.py
@@ -58,20 +58,41 @@
 from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver
 
 
+def normalize_qwen3_lpbq_block_size(config: Qwen3Config) -> int:
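+    """Resolve and validate the LPBQ block size declared in the model config.
+
+    The value must match the group size encoded in ``linear_impl_type``
+    (e.g. 16 for QNN_LPBQ_w4a16o16_G16, 32 for QNN_LPBQ_w4a16o16_G32).
+    """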
+    block_size = getattr(config, "linear_block_size", None)
+    if isinstance(block_size, int) and block_size > 0:
+        config.linear_block_size = block_size
+        return block_size
+
+    raise ValueError(
+        "Qwen3 LPBQ requires a positive `linear_block_size` in model config"
+    )
+
+
 class Qwen3MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
+        self.block_size = config.linear_block_size
         self.gate_proj = QLinearLPBQ(
-            self.hidden_size, self.intermediate_size, bias=False, block_size=16
+            self.hidden_size,
+            self.intermediate_size,
+            bias=False,
+            block_size=self.block_size,
         )
         self.up_proj = QLinearLPBQ(
-            self.hidden_size, self.intermediate_size, bias=False, block_size=16
+            self.hidden_size,
+            self.intermediate_size,
+            bias=False,
+            block_size=self.block_size,
         )
         self.down_proj = QLinearLPBQ(
-            self.intermediate_size, self.hidden_size, bias=False, block_size=16
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            block_size=self.block_size,
         )
 
         # QDQ
@@ -159,6 +180,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
+        self.block_size = config.linear_block_size
         self.head_dim = getattr(
             config, "head_dim", config.hidden_size // config.num_attention_heads
         )
@@ -173,25 +195,25 @@ def __init__(self, config: Qwen3Config, layer_idx: int):
             config.hidden_size,
             config.num_attention_heads * self.head_dim,
             bias=config.attention_bias,
-            block_size=16,
+            block_size=self.block_size,
         )
         self.k_proj = QLinearLPBQ(
             config.hidden_size,
             config.num_key_value_heads * self.head_dim,
             bias=config.attention_bias,
-            block_size=16,
+            block_size=self.block_size,
         )
         self.v_proj = QLinearLPBQ(
             config.hidden_size,
             config.num_key_value_heads * self.head_dim,
             bias=config.attention_bias,
-            block_size=16,
+            block_size=self.block_size,
         )
         self.o_proj = QLinearLPBQ(
             config.num_attention_heads * self.head_dim,
             config.hidden_size,
             bias=config.attention_bias,
-            block_size=16,
+            block_size=self.block_size,
         )
         self.q_norm = QRMSNorm(
             self.head_dim, eps=config.rms_norm_eps, quant_bits=16
@@ -514,6 +536,7 @@ def forward(self, x, position_ids):
 class Qwen3Model(Qwen3PreTrainedModel):
     def __init__(self, config: Qwen3Config):
         super().__init__(config)
+        normalize_qwen3_lpbq_block_size(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.embed_tokens = QEmbedding(
@@ -698,8 +721,12 @@ def __init__(self, config):
         self.config = config
         self.model = Qwen3Model(config)
         self.vocab_size = config.vocab_size
+        self.block_size = config.linear_block_size
         self.lm_head = QLinearLPBQ(
-            config.hidden_size, config.vocab_size, bias=False, block_size=16
+            config.hidden_size,
+            config.vocab_size,
+            bias=False,
+            block_size=self.block_size,
         )
 
         self.mllm_qualcomm_max_length = None
diff --git a/pymllm/mobile/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/mobile/backends/qualcomm/transformers/qwen3/runner.py
index 3d6cc9f6..2b0640fd 100644
--- a/pymllm/mobile/backends/qualcomm/transformers/qwen3/runner.py
+++ b/pymllm/mobile/backends/qualcomm/transformers/qwen3/runner.py
@@ -231,6 +231,23 @@ def __init__(self, model_path: str, mllm_qualcomm_max_length=2048):
         self.model.apply(freeze_qwen3_embed_tokens_weight)
         print("All PTQ weights preparation done.")
 
+    def _build_model_inputs(self, prompt: str, max_length: int | None = None):
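+        """Apply the chat template to ``prompt`` and tokenize it, truncating to
+        ``max_length`` when one is given (as during calibration)."""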
+        messages = [{"role": "user", "content": prompt}]
+        text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=False,  # Switches between thinking and non-thinking modes. Default is True.
+        )
+        tokenizer_kwargs = {"return_tensors": "pt"}
+        if max_length is not None:
+            tokenizer_kwargs.update(
+                max_length=max_length,
+                truncation=True,
+                padding=False,
+            )
+        return self.tokenizer([text], **tokenizer_kwargs).to(self.model.device)
+
     def freeze_activation(self):
         self.model.apply(disable_qdq_observer)
 
@@ -250,28 +267,24 @@ def compile(self):
         )
         print("Compile done.")
 
-    def infer(self, prompt: str):
-        messages = [{"role": "user", "content": prompt}]
-        text = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-            enable_thinking=False,  # Switches between thinking and non-thinking modes. Default is True.
-        )
-        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+    def infer(self, prompt: str, max_new_tokens: int = 8):
+        model_inputs = self._build_model_inputs(prompt)
+        input_length = model_inputs.input_ids.shape[1]
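+        # Cap generation so prompt + new tokens stays within the fixed
+        # mllm_qualcomm_max_length (the extra -1 leaves one slot of headroom).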
+        available_tokens = self.mllm_qualcomm_max_length - input_length - 1
+        if available_tokens < 1:
+            raise ValueError("Prompt exceeds configured mllm_qualcomm_max_length")
+        max_new_tokens = min(max_new_tokens, available_tokens)
 
         # conduct text completion
         generated_ids = self.model.generate(
             **model_inputs,
-            max_new_tokens=self.mllm_qualcomm_max_length
-            - len(model_inputs.input_ids[0])
-            - 1,
+            max_new_tokens=max_new_tokens,
             do_sample=False,
             temperature=None,
             top_p=None,
             top_k=None,
         )
-        output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
+        output_ids = generated_ids[0][input_length:].tolist()
 
         # parsing thinking content
         try:
@@ -329,20 +342,9 @@ def calibrate(self, num_samples=64, max_seq_length=512):
             if len(entry["text"].strip()) < 1024:
                 continue
 
-            messages = [{"role": "user", "content": entry["text"]}]
-            text = self.tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                add_generation_prompt=True,
-                enable_thinking=False,  # Switches between thinking and non-thinking modes. Default is True.
+            model_inputs = self._build_model_inputs(
+                entry["text"], max_length=max_seq_length
             )
-            model_inputs = self.tokenizer(
-                [text],
-                return_tensors="pt",
-                max_length=max_seq_length,
-                truncation=True,
-                padding=False,
-            ).to(self.model.device)
 
             # Only need Prefill stage: directly call forward
             # This will trigger observer update statistics in ActivationQDQ
diff --git a/pymllm/mobile/backends/qualcomm/transformers/qwen3/train.py b/pymllm/mobile/backends/qualcomm/transformers/qwen3/train.py
index f44fa67b..3736ce24 100644
--- a/pymllm/mobile/backends/qualcomm/transformers/qwen3/train.py
+++ b/pymllm/mobile/backends/qualcomm/transformers/qwen3/train.py
@@ -1,5 +1,4 @@
 import os
-import torch
 import argparse
 from safetensors.torch import save_model
 from pymllm.mobile.backends.qualcomm.transformers.qwen3.runner import Qwen3Quantizer
@@ -28,6 +27,12 @@ def main():
         default="为什么伟大不能被计划",
         help="Text to run inference on",
     )
+    parser.add_argument(
+        "--infer_max_new_tokens",
+        type=int,
+        default=1024,
+        help="Maximum new tokens for post-calibration inference sanity check",
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -36,15 +41,17 @@ def main():
 
     args = parser.parse_args()
 
-    m = Qwen3Quantizer(args.model_path, mllm_qualcomm_max_length=args.max_length)
+    m = Qwen3Quantizer(
+        args.model_path,
+        mllm_qualcomm_max_length=args.max_length,
+    )
 
-    # FIXME: Should disable or not.
     m.disable_fake_quant()
     m.calibrate(num_samples=args.num_samples, max_seq_length=args.max_length)
     m.enable_fake_quant()
     m.recompute_scale_zp()
     m.validate_concat_observer()
-    m.infer(args.infer_text)
+    m.infer(args.infer_text, max_new_tokens=args.infer_max_new_tokens)
     m.convert()
 
     os.makedirs(args.output_dir, exist_ok=True)