
feat(qwen3): Introduce Single Head Attention (SHA) optimization for Qualcomm qwen model #616

Merged
chenghuaWang merged 9 commits into UbiquitousLearning:main from chenghuaWang:wch-main
Jan 29, 2026

Conversation

Collaborator

@chenghuaWang chenghuaWang commented Jan 29, 2026

Summary by CodeRabbit

  • New Features

    • Added LLaMA and Qwen2 model support with AOT compilation capabilities
    • Added multiple model size configurations (LLaMA 3B; Qwen2 1.5B, 3B, 7B)
    • Added SHA (per-head attention) variants for both model families (see the sketch after this list)
    • Added Python quantization and training tools for LLaMA and Qwen2
    • Added performance measurement support in generation pipeline
    • Added example compilation and runtime executables
  • Improvements

    • Enhanced token handling type consistency
    • Optimized QNN backend settings and logging defaults
    • Improved weight quantization validation
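
For readers unfamiliar with the SHA layout, the sketch below illustrates the per-head attention idea in plain PyTorch: each head gets its own small projection and attention matmul instead of one fused multi-head computation. All module names, shapes, and sizes here are illustrative assumptions, not the repository's actual classes.

import torch
import torch.nn as nn


def sha_attention(x, q_heads, k_heads, v_heads, o_proj, head_dim):
    # x: [batch, seq, hidden]; q_heads/k_heads/v_heads: one nn.Linear per head.
    scale = head_dim ** -0.5
    per_head_outputs = []
    for q_proj, k_proj, v_proj in zip(q_heads, k_heads, v_heads):
        q = q_proj(x)                                    # [batch, seq, head_dim]
        k = k_proj(x)
        v = v_proj(x)
        attn = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
        per_head_outputs.append(attn @ v)                # [batch, seq, head_dim]
    # Concatenate the per-head results, then apply the shared output projection.
    return o_proj(torch.cat(per_head_outputs, dim=-1))


# Tiny usage example with hypothetical sizes.
hidden, num_heads, head_dim = 64, 4, 16
q_heads = [nn.Linear(hidden, head_dim) for _ in range(num_heads)]
k_heads = [nn.Linear(hidden, head_dim) for _ in range(num_heads)]
v_heads = [nn.Linear(hidden, head_dim) for _ in range(num_heads)]
o_proj = nn.Linear(num_heads * head_dim, hidden)
out = sha_attention(torch.randn(1, 8, hidden), q_heads, k_heads, v_heads, o_proj, head_dim)
print(out.shape)  # torch.Size([1, 8, 64])

The usual motivation for this layout on NPU backends is that many small, fixed-shape matmuls tend to map onto the accelerator more predictably than one large fused one; the PR itself does not include a written rationale.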


chenghuaWang and others added 9 commits January 28, 2026 09:59
…NN runtime. Update default log level and improve token generation timing metrics for better performance analysis.
…NN AOT compilation. Add new executable for SHA model and implement weight slicing utilities to enhance performance and reduce compilation time. Update CMake configuration and include necessary headers for SHA implementation.
… new executable. Update input tensor dimensions and block sizes for improved performance. Introduce compile_sha utility for weight slicing and quantization support. Adjust CMake configurations and add new model files for better integration.
…roduce new executables for model compilation and runtime, enhance CMake configurations, and implement weight slicing utilities for improved performance. Include configuration files and update input handling for tensor dimensions in AOT run.
…handling in AOT run. Modify configuration for improved performance with new target machine settings and quantization parameters.
…persAPI and PTQPass. Update weight type limits and add masking for Qnn's packing requirements in QLinearLPBQ for improved compatibility and performance.
…d quantization accuracy in AOT compilation across multiple models.
…ask and constant zero parameters in AOT compilation to enhance quantization accuracy across multiple models.
Contributor

coderabbitai Bot commented Jan 29, 2026

📝 Walkthrough

This PR introduces comprehensive QNN AOT (Ahead-of-Time) compilation infrastructure for Llama and Qwen2 models. It adds C++ model implementations with quantization support, Python PyTorch backends with quantization pipelines, configuration files, build system updates, and runtime performance measurement capabilities alongside QNN backend infrastructure modifications.

Changes

  • Build System Configuration (examples/CMakeLists.txt, examples/llama_qnn_aot/CMakeLists.txt, examples/qwen2_qnn_aot/CMakeLists.txt, examples/qwen3_qnn_aot/CMakeLists.txt): Added QNN AOT build targets for Llama and Qwen2 examples; introduced conditional executable targets for compilation and runtime; updated Qwen3 with SHA-variant compilation support.
  • Llama QNN AOT C++ Implementation (examples/llama_qnn_aot/configuration_llama3.hpp, examples/llama_qnn_aot/modeling_llama_qnn_aot.hpp, examples/llama_qnn_aot/modeling_llama_qnn_aot_sha.hpp, examples/llama_qnn_aot/compile.cpp, examples/llama_qnn_aot/compile_sha.cpp, examples/llama_qnn_aot/aot_run.cpp): New Llama3 configuration holder and comprehensive QNN AOT model implementations with standard and SHA (per-head attention) variants; compilation and runtime tools for tracing, lowering, and IR generation.
  • Llama QNN AOT Configuration (examples/llama_qnn_aot/config_3B.json, examples/llama_qnn_aot/qnn_aot_cfg_3B.json): Model architecture configuration (hidden layers, attention heads, vocab size, quantization types) and QNN AOT hardware/quantization recipe (V75 HTP, LPBQ w4a16, kv_cache w8a8).
  • Qwen2 QNN AOT C++ Implementation (examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot.hpp, examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot_sha.hpp, examples/qwen2_qnn_aot/compile.cpp, examples/qwen2_qnn_aot/compile_sha.cpp, examples/qwen2_qnn_aot/aot_run.cpp): Qwen2 QNN AOT model with standard and SHA variants; compilation and runtime tools for IR generation and token-by-token inference.
  • Qwen2 QNN AOT Configuration (examples/qwen2_qnn_aot/config_*.json, examples/qwen2_qnn_aot/qnn_aot_cfg_*.json): Model configs for the 1.5B, 3B, and 7B variants with architecture parameters; corresponding QNN AOT hardware and quantization recipes for each model size.
  • Qwen3 QNN AOT Updates (examples/qwen3_qnn_aot/compile.cpp, examples/qwen3_qnn_aot/compile_sha.cpp): Modified compile.cpp to support two-length IR generation (N=32 and N=1) with AOT environment setup and context persistence; new compile_sha.cpp for SHA parameter preparation.
  • QNN Backend Infrastructure (mllm/backends/qnn/QNNBackend.cpp, mllm/backends/qnn/aot/QnnWrappersAPI.cpp, mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp, mllm/backends/qnn/aot/passes/PTQPass.cpp): Adjusted QNN logging defaults and power optimization; added LPBQ quantization validation; expanded concat pattern support for multi-tensor concatenation; changed the int4 weight quantization range from [-8, 7] to [0, 15].
  • QNN AOT Runtime (mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp, mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp, mllm/backends/qnn/aot_rt/TokenGenerator.hpp, mllm/backends/qnn/aot_rt/TokenGenerator.cpp, mllm/backends/qnn/aot_rt/PromptProcessor.cpp): Added an optional performance-measurement flag to Runner::generate; introduced prefill/decode timing and TPS calculation; changed the token type from uint64_t to int64_t for consistency.
  • Python Llama Backend (pymllm/backends/qualcomm/transformers/llama/modeling_llama.py, pymllm/backends/qualcomm/transformers/llama/runner.py, pymllm/backends/qualcomm/transformers/llama/train.py): Complete Qualcomm-optimized PyTorch Llama implementation with quantized modules (RMSNorm, MLP, Attention) and RoPE support; LlamaQuantizer for PTQ calibration and deployment conversion; training script for the quantization pipeline.
  • Python Qwen2 Backend (pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py, pymllm/backends/qualcomm/transformers/qwen2/runner.py, pymllm/backends/qualcomm/transformers/qwen2/train.py): Complete Qualcomm-optimized PyTorch Qwen2 implementation with quantized modules and RoPE/sliding-attention support; Qwen2Quantizer for PTQ calibration and conversion; training script matching Llama's quantization flow.
  • Quantization Tooling (pymllm/backends/qualcomm/transformers/core/qlinear.py): Added a 4-bit packing mask (0x0F) in the QLinearLPBQ weight deployment path for QNN compatibility.
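
As a side note on the qlinear.py change above, here is a minimal sketch of how a 0x0F mask is typically used when packing two 4-bit values into each byte. It assumes weights already quantized to unsigned int4 (0..15) and a low-nibble-first layout; it is not the actual QLinearLPBQ code.

import numpy as np

def pack_uint4_pairs(vals):
    # vals: 1-D sequence of unsigned int4 values (0..15), even length.
    v = np.asarray(vals, dtype=np.uint8) & 0x0F      # mask keeps only the low 4 bits
    low, high = v[0::2], v[1::2]
    return (low | (high << 4)).astype(np.uint8)      # two 4-bit values per byte

packed = pack_uint4_pairs([1, 15, 7, 8])
print([hex(b) for b in packed])  # ['0xf1', '0x87']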

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Compiler as QNN AOT Compiler
    participant Model as Llama/Qwen2 Model
    participant Tracer as IR Tracer
    participant Lowerer as QNN AOT Lowering Pipeline
    participant Artifacts as Output Artifacts

    User->>Compiler: Load model config & params
    Compiler->>Compiler: Initialize QNN AOT Environment
    Compiler->>Model: Create model instance with KV cache tensors
    
    Note over Compiler,Model: First Pass (seq_len=32)
    Compiler->>Tracer: Build trace inputs for N=32
    Tracer->>Model: Perform model.trace()
    Model->>Tracer: Generate IR
    Tracer->>Lowerer: Apply QNN AOT lowering pipeline
    Lowerer->>Artifacts: Output MIR file (_32.mir)
    
    Note over Compiler,Model: Second Pass (seq_len=1)
    Compiler->>Tracer: Build trace inputs for N=1
    Tracer->>Model: Perform model.trace()
    Model->>Tracer: Generate IR
    Tracer->>Lowerer: Apply QNN AOT lowering pipeline
    Lowerer->>Artifacts: Output MIR file (_1.mir)
    
    Compiler->>Artifacts: Save QNN context (context.0)
    Artifacts->>User: Compiled AOT artifacts ready
sequenceDiagram
    participant User
    participant Runtime as QNN AOT Runtime
    participant Processor as PromptProcessor
    participant Generator as TokenGenerator
    participant Backend as QNN Backend
    participant Output as Token Output

    User->>Runtime: Provide prompt & config
    Runtime->>Runtime: Initialize QNN backend
    Runtime->>Runtime: Load compiled model
    
    Note over Runtime,Generator: Prefill Phase
    Runtime->>Processor: prefill(prompt_tokens)
    Processor->>Backend: Execute model inference
    Backend->>Processor: Return hidden states
    
    Note over Runtime,Generator: Decode Phase (Autoregressive)
    loop Generate tokens
        Runtime->>Generator: generate(tokens, seq_len, callback)
        Generator->>Backend: Execute model for next token
        Backend->>Generator: Return logits
        Generator->>Generator: Sample next token
        Generator->>Output: Stream token via callback
        Runtime->>Runtime: Measure timing (if perf enabled)
    end
    
    Runtime->>User: Return completion with metrics (optional)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • #600: Introduces Conv2D-based model implementations and weight deployment paths that overlap with the Conv2D architecture used in new Llama/Qwen2 QNN AOT models.
  • #562: Adds the foundational QNN AOT x86 flag (MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) and build wiring that gates the new QNN AOT example targets introduced in this PR.
  • #603: Modifies the same QNN AOT runtime components (Runner, TokenGenerator, PromptProcessor) with overlapping changes to the backend infrastructure used by the new examples.

Suggested reviewers

  • oreomaker
  • liang1232018
  • yirongjie

🐰 Whiskers twitching with delight

New AOT models hop into view,
Llama and Qwen2, quantized through and through,
Python and C++ dance in harmony,
Compilation and runtime, a perfect spree! ✨
Performance metrics measured with care—
This magical framework goes everywhere! 🚀

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Description check ⚠️ Warning: No pull request description was provided by the author, despite the template being available and guidelines being referenced. The description should document the purpose, scope, and testing of the SHA optimization changes. Resolution: add a comprehensive PR description following the provided template, including the rationale for the SHA optimization, the scope of changes (qwen3/qwen2/llama), testing performed, and any breaking changes or migration notes.
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 12.44%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)

  • Title check ✅ Passed: The title clearly describes the main change: introducing Single Head Attention (SHA) optimization for Qualcomm Qwen models, which is supported by extensive additions across the qwen3, qwen2, and llama implementations.



Owner

@UbiquitousLearning UbiquitousLearning left a comment

LGTM

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

Note

Due to the large number of review comments, Critical severity comments were prioritized as inline comments.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp (1)

1067-1095: Fix TODO comment format and implement quantization recipe transformation

The LLMQuantRecipeQwen3AttentionPattern::rewrite function matches the Q, K, V, and O linear projections but returns true without applying any quantization recipe transformation. The TODO comment at line 1092 lacks the required colon format (TODO: instead of TODO).

Either implement the quantization logic or document why the pattern match alone is sufficient. Additionally, correct the TODO comment to follow the required format: TODO: Maybe something need to be done here!

mllm/backends/qnn/aot/QnnWrappersAPI.cpp (1)

181-199: Guard cfg->block_size before division (LPBQ).

Line 196 divides by cfg->block_size without validation; a zero or non-divisible block size will crash or truncate. Add a defensive assert before the division.

🛡️ Proposed fix
       Qnn_BlockwiseExpansion_t blockwise_expansion;
       blockwise_expansion.axis = v->tensor_.rank() - 1;
       blockwise_expansion.scaleOffsets = nullptr;  // Will be set by setBlockwiseQuantization
+      MLLM_RT_ASSERT(cfg->block_size > 0);
+      MLLM_RT_ASSERT_EQ(v->tensor_.size(-2) % cfg->block_size, 0u);
       blockwise_expansion.numBlocksPerAxis = v->tensor_.size(-2) / cfg->block_size;

As per coding guidelines: Validate inputs for public APIs and critical internal functions.

mllm/backends/qnn/aot_rt/TokenGenerator.cpp (1)

101-142: Guard against empty vectors and validate token IDs to prevent undefined behavior and silent wraparound.

tokens.back() is undefined behavior for empty vectors, and negative or out-of-range token IDs will silently wrap when converted to uint64_t (for prepare_io) or uint32_t (for input_ids). Add runtime assertions to validate preconditions.

🔧 Proposed fix
+  MLLM_RT_ASSERT(!tokens.empty());
   int64_t current_pos = start_pos;
   int64_t next_token = tokens.back();
+  MLLM_RT_ASSERT(next_token >= 0 && next_token < static_cast<int64_t>(config_.vocab_size));
-    prepare_io(next_token, current_pos);
+    prepare_io(static_cast<uint64_t>(next_token), current_pos);
-    if (eos_ids_ && eos_ids_->count(next_token)) { break; }
+    if (eos_ids_ && eos_ids_->count(static_cast<uint64_t>(next_token))) { break; }

Per coding guidelines: Validate inputs for public APIs and critical internal functions. Ensure proper error handling.

examples/qwen3_qnn_aot/compile.cpp (1)

15-35: Mark model and config paths as required arguments.

Both model_path and model_cfg_path are used unconditionally in the compile flow (lines 36 and 38) without validation. Mark them as required to prevent runtime failures.

🔧 Proposed fix
-  auto& model_path = Argparse::add<std::string>("-m|--model_path").help("Model file path.");
-  auto& model_cfg_path = Argparse::add<std::string>("-c|--config").help("Model config file path.");
+  auto& model_path = Argparse::add<std::string>("-m|--model_path").help("Model file path.").required(true);
+  auto& model_cfg_path = Argparse::add<std::string>("-c|--config").help("Model config file path.").required(true);
🤖 Fix all issues with AI agents
In `@examples/llama_qnn_aot/modeling_llama_qnn_aot_sha.hpp`:
- Around line 539-543: lm_head_ is registered only when cfg.tie_word_embeddings
is true but is invoked unconditionally in trace(), which can lead to use of an
uninitialized module; fix by either always registering lm_head_ (register
lm_head_ regardless of cfg.tie_word_embeddings so lm_head_ is valid) or by
guarding the call in trace() with the tie flag (use tie_word_embeddings_ or
cfg.tie_word_embeddings to wrap the lm_head_() invocation so it’s only called
when registered), and ensure any related QDQ calls are likewise guarded or
present only when lm_head_ exists.

In `@examples/llama_qnn_aot/modeling_llama_qnn_aot.hpp`:
- Around line 415-419: The field lm_head_ is only registered inside the if
(cfg.tie_word_embeddings) branch but later invoked unconditionally (lm_head_()),
causing undefined behavior when tie_word_embeddings is false; always register
lm_head_ by moving or duplicating the reg<nn::Conv2D>("lm_head",
cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY) call out of the conditional so
lm_head_ is constructed regardless of cfg.tie_word_embeddings, and keep the
existing comment/logic that when tie_word_embeddings is true the lm_head_ weight
corresponds to model.embed_tokens.weight (adjust any weight-tying code elsewhere
to handle the tied vs untied cases).

In `@pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py`:
- Around line 90-123: The rotate_half call in apply_rotary_pos_emb fails because
rotate_half currently requires x_observer, x2_neg_fake_quant and
concat_observer; change rotate_half signature to make those three parameters
optional (e.g., default to None) and update its body to branch: if
x2_neg_fake_quant and concat_observer are provided, perform the quantized path
using x_observer/x2_neg_fake_quant/concat_observer exactly as before; otherwise
perform the plain rotation (slice x into halves, negate second half, concat) and
return the non-quantized tensor. Keep apply_rotary_pos_emb unchanged (calling
rotate_half(q)/rotate_half(k)) so it falls back to the non-quantized behavior
when observers/fake-quantizers are not supplied.

In `@pymllm/backends/qualcomm/transformers/qwen2/train.py`:
- Around line 31-35: The --output_dir CLI argument is not marked required which
can lead to args.output_dir being None and causing os.makedirs(args.output_dir,
exist_ok=True) in train.py to raise an error; update the parser.add_argument
call for "--output_dir" (the one creating the output_dir flag) to include
required=True (and optionally add a helpful help message) so args.output_dir is
always set before the os.makedirs(args.output_dir, ...) call.
🟠 Major comments (19)
pymllm/backends/qualcomm/transformers/qwen2/runner.py-70-73 (1)

70-73: Bug: Incorrect attribute reference module.scale.

Line 72 references module.scale, but based on the surrounding code context, this should be module.fake_quant.scale. The scale attribute belongs to fake_quant, not directly to the ActivationQDQ module.

🐛 Proposed fix
             for key, value in module.fake_quant.named_parameters():
                 if value is module.fake_quant.scale:
-                    print(f"{module._get_name()}.{key}: {module.scale}")
+                    print(f"{module._get_name()}.{key}: {module.fake_quant.scale}")
                     break
mllm/backends/qnn/aot/passes/PTQPass.cpp-50-50 (1)

50-50: Use spec-driven quantization range instead of hardcoding unsigned int4 bounds.

The hardcoded 0–15 range assumes unsigned int4 weights, but QuantizationSpecLPBQ supports both signed and unsigned int4 formats. The codebase contains signed int4 recipes with ranges like -8..7 and -7..7 (see LLMQuantRecipePass.cpp). This assertion will fail at runtime for those cases. Use this_spec->quant_min and this_spec->quant_max instead, matching the pattern used by AsymPerTensor and SymPerTensor specs.

examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp-636-638 (1)

636-638: Storing const Qwen3Config& as member risks dangling reference.

If the Qwen3Config passed to the constructor is a temporary or goes out of scope, cfg will become a dangling reference. Consider storing by value or using std::shared_ptr.

🔧 Proposed fix
  private:
-  const Qwen3Config& cfg;
+  Qwen3Config cfg;
examples/llama_qnn_aot/configuration_llama3.hpp-23-42 (1)

23-42: Guard against invalid num_attention_heads when computing head_dim.

If num_attention_heads is 0 (from a missing or invalid JSON field), line 41 will divide by zero. Add a validation guard before the division.

Suggested fix
    if (data().contains("head_dim")) {
      head_dim = data()["head_dim"];
    } else {
      if (num_attention_heads <= 0) {
        throw std::invalid_argument("num_attention_heads must be > 0");
      }
      if (hidden_size % num_attention_heads != 0) {
        throw std::invalid_argument("hidden_size must be divisible by num_attention_heads");
      }
      head_dim = hidden_size / num_attention_heads;
    }
examples/llama_qnn_aot/aot_run.cpp-46-55 (1)

46-55: The arange assignment completely ignores user input and creates a dtype mismatch.

The code overwrites the tokenized sequence from convertMessage() with synthetic sequential indices, discarding the "hello" input. Additionally, the code uses kInt64 while the compile path expects kInt32, creating a type inconsistency that can cause runtime failures.

🔧 Suggested fix
-  input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1});
+  // Keep tokenizer output; if a synthetic input is required for testing, gate it behind a flag and match the model's expected dtype (kInt32).

Either preserve the tokenizer output or, if synthetic input is needed, gate it behind a command-line flag and use the correct dtype expected by the compiled model.

examples/llama_qnn_aot/modeling_llama_qnn_aot_sha.hpp-620-620 (1)

620-620: Storing config by const reference risks dangling reference.

cfg is stored as const Llama3Config&. If the config object passed to the constructor is destroyed before the LlamaForCausalLM_SHA instance, this becomes a dangling reference leading to undefined behavior.

🛡️ Proposed fix: store by value or shared_ptr
  private:
-  const Llama3Config& cfg;
+  Llama3Config cfg;
   LlamaTextSHA llm;

Or if sharing is intentional:

std::shared_ptr<const Llama3Config> cfg;
pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py-591-611 (1)

591-611: causal_mask_mapping is built but never used (sliding attention + padding masks ignored).
You compute per-layer masks but always pass a fixed causal_mask to every layer. Use the mapping based on decoder_layer.attention_type so sliding attention and any provided attention_mask take effect.

Proposed fix
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            hidden_states = decoder_layer(
-                hidden_states,
-                attention_mask=causal_mask,
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            layer_mask = causal_mask_mapping.get(
+                decoder_layer.attention_type, causal_mask_mapping["full_attention"]
+            )
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 **kwargs,
             )

Also applies to: 660-669

pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py-614-637 (1)

614-637: mllm_qualcomm_max_length hard requirement makes default forward unusable.
If callers don’t set it, the function asserts and fails on first call. Default to config.max_position_embeddings (or raise a clear error) and compute embeddings when either buffer is missing.

Proposed fix
-        if self.mllm_max_sin_embedding is None and self.mllm_max_cos_embedding is None:
-            mllm_qualcomm_max_length = kwargs.get("mllm_qualcomm_max_length", None)
-            assert mllm_qualcomm_max_length is not None
+        if self.mllm_max_sin_embedding is None or self.mllm_max_cos_embedding is None:
+            mllm_qualcomm_max_length = kwargs.get("mllm_qualcomm_max_length")
+            if mllm_qualcomm_max_length is None:
+                mllm_qualcomm_max_length = self.config.max_position_embeddings
+            if mllm_qualcomm_max_length is None:
+                raise ValueError("mllm_qualcomm_max_length must be provided")
             max_position_ids = torch.arange(
                 0,
                 mllm_qualcomm_max_length,
pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py-645-658 (1)

645-658: input_ids can be None but seq_len derives from it.
When inputs_embeds is provided, this raises. Use inputs_embeds/hidden_states to derive seq_len.

Proposed fix
-        _, seq_len = input_ids.shape
+        seq_len = (
+            inputs_embeds.shape[1]
+            if input_ids is None
+            else input_ids.shape[1]
+        )
pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py-267-276 (1)

267-276: Avoid double-quantizing K/V inputs by overwriting hidden_states.
hidden_states is replaced by the Q path QDQ output and then fed to K/V QDQ, which applies quantization twice and defeats per-branch calibration. Keep the original input for K/V.

Proposed fix
-        hidden_states = self.q_proj_input_qdq(hidden_states)
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        hidden_states_q = self.q_proj_input_qdq(hidden_states)
+        query_states = self.q_proj(hidden_states_q).view(hidden_shape).transpose(1, 2)
         query_states = self.q_proj_output_qdq(query_states)
 
-        hidden_states_k = self.k_proj_input_qdq(hidden_states)
+        hidden_states_k = self.k_proj_input_qdq(hidden_states)
         key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2)
         key_states = self.k_proj_output_qdq(key_states)
 
-        hidden_states_v = self.v_proj_input_qdq(hidden_states)
+        hidden_states_v = self.v_proj_input_qdq(hidden_states)
         value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2)
pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py-738-739 (1)

738-739: Unconditional kwargs.update overwrites caller-provided mllm_qualcomm_max_length.

The code at line 738 unconditionally overwrites any caller-provided mllm_qualcomm_max_length value with self.mllm_qualcomm_max_length, which defaults to None. This causes the guard in the initialization path (which asserts the value is not None) to fail. Only inject the attribute when it's set, and preserve any caller-provided value.

Proposed fix
-        kwargs.update({"mllm_qualcomm_max_length": self.mllm_qualcomm_max_length})
+        if self.mllm_qualcomm_max_length is not None:
+            kwargs.setdefault("mllm_qualcomm_max_length", self.mllm_qualcomm_max_length)
pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py-639-643 (1)

639-643: Position embedding indexing assumes batch size 1 and will fail with batch > 1.

position_ids.squeeze(0) at lines 641-642 only removes dimensions of size 1. When batch_size > 1, it leaves position_ids as a 2D tensor [batch_size, seq_len], breaking the indexing operation and producing mis-shaped embeddings passed to decoder layers. Either enforce batch=1 explicitly with a validation check in the forward method, or refactor the indexing to handle arbitrary batch sizes by iterating over batch dimension or using advanced indexing.
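
A minimal sketch of the advanced-indexing alternative (names and shapes are illustrative; the cached cos/sin tensors are assumed to be [1, max_len, head_dim]):

import torch

def gather_rope(cos_cache, sin_cache, position_ids):
    # cos_cache/sin_cache: [1, max_len, head_dim]; position_ids: [batch, seq]
    cos = cos_cache.squeeze(0)[position_ids]  # -> [batch, seq, head_dim]
    sin = sin_cache.squeeze(0)[position_ids]
    return cos, sin

# Works for any batch size:
cos_cache = torch.randn(1, 128, 64)
sin_cache = torch.randn(1, 128, 64)
pos = torch.arange(8).unsqueeze(0).expand(2, -1)   # batch=2, seq=8
cos, sin = gather_rope(cos_cache, sin_cache, pos)
print(cos.shape)  # torch.Size([2, 8, 64])

This mirrors the refactor suggested for the Llama variant later in this review.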

pymllm/backends/qualcomm/transformers/llama/train.py-9-35 (1)

9-35: Make --model_path and --output_dir required (current defaults lead to runtime errors).

Without these, the script fails later with unclear errors. Validate up front.

🛠️ Suggested CLI validation
     parser.add_argument(
         "--model_path",
         type=str,
-        default="",
+        required=True,
         help="Path to the Llama model directory",
     )
@@
     parser.add_argument(
         "--output_dir",
         type=str,
+        required=True,
         help="Directory to save the quantized model",
     )
As per coding guidelines: validate inputs for public APIs and critical internal functions.
pymllm/backends/qualcomm/transformers/llama/modeling_llama.py-163-214 (1)

163-214: Honor config.hidden_act (or assert it) to avoid silent activation mismatch.

The MLP always computes a SiLU‑style activation but ignores config.hidden_act. If a config uses a different activation, this silently produces the wrong model. If only SiLU is supported, fail fast.

🛠️ Suggested guard
     def forward(self, x):
         x = self.up_proj_input_qdq(x)
         up_result = self.up_proj_output_qdq(self.up_proj(x))
         gate_result = self.gate_proj_output_qdq(self.gate_proj(x))

-        # SiLU or other activation
+        if self.config.hidden_act != "silu":
+            raise ValueError(
+                "Qualcomm Llama MLP only supports SiLU; set hidden_act='silu'."
+            )
+        # SiLU activation
         gate_result = self.act_output_qdq(
             gate_result * self.sigmoid_output_qdq(F.sigmoid(gate_result))
         )
As per coding guidelines: validate inputs for public APIs and critical internal functions.
pymllm/backends/qualcomm/transformers/llama/modeling_llama.py-127-160 (1)

127-160: Fix rotate_half/apply_rotary_pos_emb signature mismatch (runtime TypeError).

apply_rotary_pos_emb calls rotate_half(q) with a single argument, but rotate_half currently requires four. This will crash the first time apply_rotary_pos_emb is used. Consider making the QDQ-specific args optional and provide a plain fallback rotation. Also silence the intentionally-unused position_ids for lint.

🛠️ Proposed fix (backward‑compatible and lint‑clean)
-def rotate_half(
-    x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver
-):
+def rotate_half(
+    x,
+    _x_observer=None,
+    x2_neg_fake_quant: Optional[ActivationQDQ] = None,
+    concat_observer: Optional[ConcatObserver] = None,
+):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2 :]
-    return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1))
+    if x2_neg_fake_quant is not None and concat_observer is not None:
+        return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1))
+    return torch.cat((-x2, x1), dim=-1)

-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):  # noqa: ARG001
pymllm/backends/qualcomm/transformers/llama/modeling_llama.py-659-688 (1)

659-688: Fix RoPE cache indexing for batch>1 and avoid assert for required kwargs.

position_ids.squeeze(0) assumes batch size 1 and will produce incorrect shapes when batch>1. Also, assert is stripped in optimized mode; raise a real error for missing mllm_qualcomm_max_length.

🛠️ Suggested fix
         if self.mllm_max_sin_embedding is None and self.mllm_max_cos_embedding is None:
             mllm_qualcomm_max_length = kwargs.get("mllm_qualcomm_max_length", None)
-            assert mllm_qualcomm_max_length is not None
+            if mllm_qualcomm_max_length is None:
+                raise ValueError(
+                    "mllm_qualcomm_max_length is required to build the RoPE cache."
+                )
@@
-        position_embeddings = (
-            self.mllm_max_cos_embedding[:, position_ids.squeeze(0), :],
-            self.mllm_max_sin_embedding[:, position_ids.squeeze(0), :],
-        )
+        cos_cache = self.mllm_max_cos_embedding.squeeze(0)
+        sin_cache = self.mllm_max_sin_embedding.squeeze(0)
+        position_embeddings = (cos_cache[position_ids], sin_cache[position_ids])
As per coding guidelines: validate inputs for public APIs and critical internal functions.
pymllm/backends/qualcomm/transformers/llama/runner.py-197-207 (1)

197-207: Guard against missing CUDA (current code hard‑fails on CPU‑only hosts).

The unconditional self.model.cuda() call will fail on CPU-only machines without a clear error message. Qualcomm quantization appears to require CUDA, so validate availability upfront and fail with a descriptive error.

Suggested fix
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is required for Qualcomm Llama quantization.")
         self.model.cuda()

Note: This pattern appears in other Qualcomm quantizer classes (Qwen2, Qwen3) and should be addressed consistently. Per coding guidelines: ensure functions that can fail raise appropriate errors.

pymllm/backends/qualcomm/transformers/llama/runner.py-275-282 (1)

275-282: Use use_streaming=True instead of hardcoding trust_remote_code=True to avoid security and resource issues.

The code comment claims streaming is enabled, but no streaming parameter is passed to MsDataset.load(). This can trigger massive downloads without streaming. Additionally, trust_remote_code=True allows arbitrary code execution from the remote dataset repository—it should be an explicit opt-in parameter, not hardcoded.

🛠️ Suggested update
-    def calibrate(self, num_samples=64, max_seq_length=512):
+    def calibrate(
+        self,
+        num_samples=64,
+        max_seq_length=512,
+        use_streaming: bool = True,
+        trust_remote_code: bool = False,
+    ):
@@
         dataset = MsDataset.load(
             "modelscope/wikitext",
             subset_name="wikitext-103-v1",
             split="train",
+            use_streaming=use_streaming,
+            trust_remote_code=trust_remote_code,
         )
pymllm/backends/qualcomm/transformers/llama/runner.py-237-249 (1)

237-249: Add validation to handle prompts exceeding mllm_qualcomm_max_length.

The infer() method will fail with an unhelpful transformers error when max_new_tokens becomes ≤0. Validate the computed value and raise a user-friendly error before calling generate(), as per the coding guidelines for public API input validation.

🛠️ Suggested guard
-        generated_ids = self.model.generate(
+        max_new = (
+            self.mllm_qualcomm_max_length
+            - len(model_inputs.input_ids[0])
+            - 1
+        )
+        if max_new <= 0:
+            raise ValueError("Prompt length exceeds max_length; shorten the prompt.")
+        generated_ids = self.model.generate(
             **model_inputs,
-            max_new_tokens=self.mllm_qualcomm_max_length
-            - len(model_inputs.input_ids[0])
-            - 1,
+            max_new_tokens=max_new,
             do_sample=False,
             temperature=None,
             top_p=None,
             top_k=None,
         )
🟡 Minor comments (18)
pymllm/backends/qualcomm/transformers/qwen2/runner.py-278-285 (1)

278-285: Remove Chinese characters from comment and clarify streaming parameter usage.

Line 279 contains a mixed-language comment with Chinese characters ("几十G"). Comments should be in English for consistency and maintainability. Additionally, the comment references streaming=True, but the correct parameter name in the ModelScope API is use_streaming (not streaming). More importantly, the MsDataset.load() call doesn't use either parameter, making the comment misleading. Either implement use_streaming=True if streaming is desired, or update the comment to reflect the actual behavior.

mllm/backends/qnn/aot/passes/PTQPass.cpp-49-49 (1)

49-49: Fix the FIXME tag format.

Please use the required FIXME: prefix to comply with the comment format rule.

🔧 Proposed fix
-      // FIXME weight maybe error, Check qnn eats int8 or uint8. Here weight using int8 to store int4.
+      // FIXME: weight maybe error, Check qnn eats int8 or uint8. Here weight using int8 to store int4.

As per coding guidelines, TODO and FIXME comments must be written as 'TODO:' or 'FIXME:' followed by UTF-8 text that adheres to character set rules.

mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp-6-15 (1)

6-15: Replace magic number with named constant and use steady_clock for elapsed time measurement.

The code uses high_resolution_clock for measuring elapsed time (lines 72–127), but steady_clock is the correct choice—high_resolution_clock is not guaranteed to be monotonic and may jump backward or forward. Additionally, the magic number 1000000.0 appears twice in the TPS calculations and should be a named constant per coding guidelines.

The fmt library is already properly linked via MllmRT (public dependency), so no build changes are needed.

♻️ Suggested refactor
+  using Clock = std::chrono::steady_clock;
+  constexpr double kMicrosPerSec = 1e6;
@@
-  std::chrono::high_resolution_clock::time_point prefill_start, prefill_end;
-  if (perf) { prefill_start = std::chrono::high_resolution_clock::now(); }
+  Clock::time_point prefill_start, prefill_end;
+  if (perf) { prefill_start = Clock::now(); }
@@
-  std::chrono::high_resolution_clock::time_point decode_start, decode_end;
-  if (perf) { decode_start = std::chrono::high_resolution_clock::now(); }
+  Clock::time_point decode_start, decode_end;
+  if (perf) { decode_start = Clock::now(); }
@@
-    if (prefill_duration > 0) { prefill_tps = (double)prefill_token_count / (prefill_duration / 1000000.0); }
+    if (prefill_duration > 0) { prefill_tps = (double)prefill_token_count / (prefill_duration / kMicrosPerSec); }
@@
-    if (decode_duration > 0 && generated_count > 0) { decode_tps = (double)generated_count / (decode_duration / 1000000.0); }
+    if (decode_duration > 0 && generated_count > 0) { decode_tps = (double)generated_count / (decode_duration / kMicrosPerSec); }
examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot_sha.hpp-620-620 (1)

620-620: Storing cfg as a const reference risks dangling reference.

Same issue as in the non-SHA header - if the config object is destroyed before the model, this becomes a dangling reference.

examples/qwen2_qnn_aot/compile_sha.cpp-39-43 (1)

39-43: Unreachable code after MLLM_ERROR_EXIT.

MLLM_ERROR_EXIT (based on the name) likely terminates the program, making lines 41-42 unreachable. Either remove the unreachable code or change the error handling approach.

♻️ Proposed fix
   if (!qnn_aot_cfg_files.isSet()) {
-    MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided");
-    Argparse::printHelp();
-    return -1;
+    Argparse::printHelp();
+    MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided");
   }
examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot.hpp-502-502 (1)

502-502: Storing cfg as a const reference risks dangling reference.

If the Qwen3Config object passed to the constructor is destroyed before Qwen2ForCausalLM, subsequent access to cfg will be undefined behavior. Consider storing by value or using std::shared_ptr.

🛡️ Proposed fix - store by value
-  const qwen3::Qwen3Config& cfg;
+  qwen3::Qwen3Config cfg;

And update the constructor to copy:

-  explicit Qwen2ForCausalLM(const qwen3::Qwen3Config& cfg) : cfg(cfg) {
+  explicit Qwen2ForCausalLM(const qwen3::Qwen3Config& config) : cfg(config) {
examples/qwen2_qnn_aot/aot_run.cpp-42-48 (1)

42-48: Tokenization result is discarded - appears to be debug/test code.

The tokenizer converts the prompt at line 42, but then immediately overwrites input_tensor["sequence"] with an arange tensor at line 44. The actual tokenized "hello" prompt is never used. Combined with the debug prints at lines 47-48, this appears to be debug/test code.

If this is intentional for benchmarking/testing with synthetic input, consider:

  1. Adding a command-line flag to choose between real tokenization and synthetic input
  2. Adding a comment explaining the intent
  3. Removing the unused tokenization call
♻️ Proposed fix if tokenization should be used
   auto tokenizer = mllm::models::qwen3::Qwen3Tokenizer(tokenizer_path.get());

-  auto input_tensor = tokenizer.convertMessage({.prompt = "hello"});
-
-  input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1});
+  // Use actual tokenized input or synthetic sequence based on use case
+  auto input_tensor = tokenizer.convertMessage({.prompt = "hello"});
+  // Pad/truncate sequence to desired length if needed

-  // DBG:
-  mllm::print(input_tensor["sequence"].shape());
-  mllm::print(input_tensor["sequence"]);
examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp-634-634 (1)

634-634: Empty forward() implementation returns empty map.

The forward method returns an empty ARGenerationOutputPast. If this class is used for inference (not just tracing), callers may receive unexpected empty results. Consider adding a comment explaining this is trace-only, or implement the forward pass.

examples/qwen3_qnn_aot/compile_sha.cpp-39-43 (1)

39-43: Unreachable code after MLLM_ERROR_EXIT.

Same issue as in compile.cpp - lines 41-42 are unreachable after MLLM_ERROR_EXIT.

🔧 Proposed fix
   if (!qnn_aot_cfg_files.isSet()) {
+    Argparse::printHelp();
     MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided");
-    Argparse::printHelp();
-    return -1;
   }
examples/qwen2_qnn_aot/compile.cpp-31-35 (1)

31-35: Unreachable code after MLLM_ERROR_EXIT.

MLLM_ERROR_EXIT likely terminates the program, making lines 33-34 unreachable. Either remove the dead code or restructure if you intend to show help before exiting.

🔧 Proposed fix
   if (!qnn_aot_cfg_files.isSet()) {
+    Argparse::printHelp();
     MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided");
-    Argparse::printHelp();
-    return -1;
   }
examples/llama_qnn_aot/aot_run.cpp-41-45 (1)

41-45: Tokenizer comment doesn’t match implementation.

The comment says “Qwen3 tokenizer as a placeholder,” but the code uses TinyLlamaTokenizer. Update the comment or swap to the intended tokenizer.

📝 Suggested comment fix
-  // Note: Using Qwen3 tokenizer as a placeholder.
+  // Note: Using TinyLlama tokenizer as a placeholder.
examples/llama_qnn_aot/compile_sha.cpp-9-10 (1)

9-10: Usage comment doesn’t match built binary name.

CMake defines mllm-llama-aot-c-sha, not compile_sha. Update the usage comment to avoid confusion.

📝 Suggested comment fix
-//   ./compile_sha -m /path/to/model.mllm -c /path/to/config.json -aot_cfg /path/to/qnn_aot_cfg.json
+//   ./mllm-llama-aot-c-sha -m /path/to/model.mllm -c /path/to/config.json -aot_cfg /path/to/qnn_aot_cfg.json
examples/llama_qnn_aot/compile_sha.cpp-23-43 (1)

23-43: Mark model_path and model_cfg_path as required arguments.

Both model_path and model_cfg_path are used immediately after parsing (lines 45, 48) without validation. Marking them as required ensures clear error messages at argument parsing time rather than failing later with unclear errors from the constructor or load function. This follows the existing pattern already used for qnn_aot_cfg_files (line 39).

Proposed fix
-  auto& model_path = Argparse::add<std::string>("-m|--model_path").help("Model file path.");
-  auto& model_cfg_path = Argparse::add<std::string>("-c|--config").help("Model config file path.");
+  auto& model_path = Argparse::add<std::string>("-m|--model_path").help("Model file path.").required(true);
+  auto& model_cfg_path = Argparse::add<std::string>("-c|--config").help("Model config file path.").required(true);
examples/llama_qnn_aot/CMakeLists.txt-1-14 (1)

1-14: Improve code consistency: gate all QNN AOT targets uniformly.

All targets in this file link MllmQNNBackend and are only relevant when QNN AOT is enabled. While the parent examples/CMakeLists.txt already gates the entire subdirectory, the internal inconsistency—where compile targets are gated but the runner is not—reduces code clarity. For maintainability, move mllm-llama-aot-runner inside the if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) block to treat all targets uniformly.

🔧 Proposed fix
 if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE)
   add_executable(mllm-llama-aot-c compile.cpp)
   target_link_libraries(mllm-llama-aot-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend)
   target_include_directories(mllm-llama-aot-c PRIVATE ${MLLM_INCLUDE_DIR})
 
   add_executable(mllm-llama-aot-c-sha compile_sha.cpp)
   target_link_libraries(mllm-llama-aot-c-sha PRIVATE MllmRT MllmCPUBackend MllmQNNBackend)
   target_include_directories(mllm-llama-aot-c-sha PRIVATE ${MLLM_INCLUDE_DIR})
+
+  add_executable(mllm-llama-aot-runner aot_run.cpp)
+  target_link_libraries(mllm-llama-aot-runner PRIVATE MllmRT MllmCPUBackend MllmQNNBackend)
+  target_include_directories(mllm-llama-aot-runner PRIVATE ${MLLM_INCLUDE_DIR})
 endif()
-
-add_executable(mllm-llama-aot-runner aot_run.cpp)
-target_link_libraries(mllm-llama-aot-runner PRIVATE MllmRT MllmCPUBackend MllmQNNBackend)
-target_include_directories(mllm-llama-aot-runner PRIVATE ${MLLM_INCLUDE_DIR})
examples/llama_qnn_aot/compile.cpp-15-35 (1)

15-35: Add explicit validation for required model and config parameters.

model_path and model_cfg_path are used unconditionally at lines 37 and 39 but lack validation. Without them, the program fails with unclear I/O errors. Mark them as required to match the pattern used for qnn_aot_cfg_files and provide clear error messaging.

Proposed fix
-  auto& model_path = Argparse::add<std::string>("-m|--model_path").help("Model file path.");
-  auto& model_cfg_path = Argparse::add<std::string>("-c|--config").help("Model config file path.");
+  auto& model_path = Argparse::add<std::string>("-m|--model_path").help("Model file path.").required(true);
+  auto& model_cfg_path = Argparse::add<std::string>("-c|--config").help("Model config file path.").required(true);
pymllm/backends/qualcomm/transformers/llama/modeling_llama.py-231-240 (1)

231-240: Silence unused kwargs to keep Ruff clean.

🔧 Minimal fix
-    **kwargs: Unpack[TransformersKwargs],
+    **_kwargs: Unpack[TransformersKwargs],
pymllm/backends/qualcomm/transformers/llama/modeling_llama.py-370-379 (1)

370-379: Rename unused **kwargs to avoid lint noise.

🔧 Minimal fix
-        **kwargs: Unpack[TransformersKwargs],
+        **_kwargs: Unpack[TransformersKwargs],
pymllm/backends/qualcomm/transformers/llama/runner.py-52-56 (1)

52-56: Avoid blind except Exception or explicitly justify it.

If the broad catch is intentional, add a # noqa: BLE001 to make the intent explicit (and consider logging the module name for debugging).

🔧 Minimal lint‑clean tweak
-            except Exception as e:
+            except Exception as e:  # noqa: BLE001

Comment on lines +539 to +543
    if (cfg.tie_word_embeddings) {
      // NOTE:
      // model.lm_head.weight is quantization weights of model.embed_tokens.weight
      lm_head_ = reg<nn::Conv2D>("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY);
    }
Contributor

⚠️ Potential issue | 🔴 Critical

lm_head_ used unconditionally but only registered when tie_word_embeddings is true.

At lines 539-543, lm_head_ is only registered if cfg.tie_word_embeddings is true. However, at line 604, lm_head_() is called unconditionally in trace(). If tie_word_embeddings is false, lm_head_ will be uninitialized, likely causing a crash or undefined behavior.

🐛 Proposed fix

Either always register lm_head_:

-    if (cfg.tie_word_embeddings) {
-      // NOTE:
-      // model.lm_head.weight is quantization weights of model.embed_tokens.weight
-      lm_head_ = reg<nn::Conv2D>("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY);
-    }
+    // Always register lm_head for causal LM
+    lm_head_ = reg<nn::Conv2D>("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY);

Or guard the usage in trace():

if (tie_word_embeddings_) {
  sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq"));
  sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq");
}

Also applies to: 604-604


Comment on lines +415 to +419
    if (cfg.tie_word_embeddings) {
      // NOTE:
      // model.lm_head.weight is quantization weights of model.embed_tokens.weight
      lm_head_ = reg<nn::Conv2D>("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY);
    }
Contributor

⚠️ Potential issue | 🔴 Critical

lm_head_ used unconditionally but only registered conditionally.

At lines 415-419, lm_head_ is only registered if cfg.tie_word_embeddings is true. However, at line 486, lm_head_() is called unconditionally. This will cause undefined behavior when tie_word_embeddings is false.

🐛 Proposed fix

Always register lm_head_:

-    if (cfg.tie_word_embeddings) {
-      // NOTE:
-      // model.lm_head.weight is quantization weights of model.embed_tokens.weight
-      lm_head_ = reg<nn::Conv2D>("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY);
-    }
+    // Always register lm_head for causal LM
+    lm_head_ = reg<nn::Conv2D>("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY);

Also applies to: 486-486


Comment on lines +90 to +123
def rotate_half(
    x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver
):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1))


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Scripts executed to verify the call sites:

cat -n pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py | sed -n '85,125p'

# Search for all calls to rotate_half
rg 'rotate_half' pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py

# Check imports at the top of the file
head -50 pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py

# Search for all rotate_half calls with more context
rg -B2 -A2 'rotate_half\(' pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py

# Check if there are any other files using rotate_half
rg 'rotate_half' pymllm/

# Get complete context for rotate_half calls in the file
rg -B5 -A5 'rotate_half\(' pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py | head -80


apply_rotary_pos_emb calls rotate_half with incorrect arguments.

The rotate_half function (lines 90–96) requires four parameters: x, x_observer, x2_neg_fake_quant, and concat_observer. However, apply_rotary_pos_emb (lines 121–122) calls it with only x (rotate_half(q) and rotate_half(k)), which will raise a TypeError at runtime.

Make the quantization parameters optional with a fallback to non-quantized rotation:

Proposed fix
def rotate_half(
-    x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver
+    x,
+    x_observer=None,
+    x2_neg_fake_quant: Optional[ActivationQDQ] = None,
+    concat_observer: Optional[ConcatObserver] = None,
):
     """Rotates half the hidden dims of the input."""
+    del x_observer  # kept for signature compatibility
     x1 = x[..., : x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2 :]
-    return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1))
+    if x2_neg_fake_quant is None or concat_observer is None:
+        return torch.cat((-x2, x1), dim=-1)
+    return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1))
🧰 Tools
🪛 Ruff (0.14.14)

91-91: Unused function argument: x_observer

(ARG001)


99-99: Unused function argument: position_ids

(ARG001)


Comment on lines +31 to +35
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Directory to save the quantized model",
    )
Contributor

⚠️ Potential issue | 🔴 Critical

Missing required=True for --output_dir argument.

The --output_dir argument has no default value and is not marked as required. If a user omits this flag, args.output_dir will be None, causing os.makedirs(args.output_dir, ...) on line 50 to raise a TypeError.

🐛 Proposed fix
     parser.add_argument(
         "--output_dir",
         type=str,
+        required=True,
         help="Directory to save the quantized model",
     )

@chenghuaWang chenghuaWang merged commit fe6c481 into UbiquitousLearning:main Jan 29, 2026
4 checks passed