feat(qwen3): Add configuration files and enhance Qwen3 model with layer indexing and quantization improvements. #611
Conversation
…er indexing and quantization improvements. Introduce new JSON configurations for 1.7B and 4B models, and update model architecture to support layer-specific operations and weight management.
…e graph structure, including quantization specifications and layer operations. This implementation enhances model performance and supports advanced quantization techniques.
…en3Text class to streamline model input handling.
📝 Walkthrough
The PR introduces Qwen3 4B model configurations for QNN AOT compilation and enhances the quantization pipeline with per-layer QDQ conditional logic, embedding weight synchronization for tied embeddings, and scale/zero-point recomputation with concat observer validation utilities.
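The tied-embeddings synchronization mentioned above is the one non-obvious step: `lm_head` must carry the same matrix as `embed_tokens` before calibration so both see identical quantization statistics. A minimal sketch of the idea, using a hypothetical `TinyLM` module (the real methods live on the PR's quantizer/model and may differ):

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Hypothetical stand-in for a tied-embeddings LM; not the repo's class."""
    def __init__(self, vocab: int = 128, dim: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)

    def copy_lm_head_weight_from_embed_tokens(self) -> None:
        # With tie_word_embeddings, lm_head reuses the embedding matrix.
        with torch.no_grad():
            self.lm_head.weight.copy_(self.embed_tokens.weight)

    def freeze_embed_tokens_weight(self) -> None:
        # Freeze so calibration cannot drift the two copies apart.
        self.embed_tokens.weight.requires_grad_(False)
```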
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer
    participant Model as Qwen3Model
    participant Calibrator as Calibration
    participant Quantizer as Qwen3Quantizer
    participant FQ as FakeQuantize
    participant Observer as ConcatObserver
    participant Converter as AOT Converter
    Trainer->>Quantizer: Initialize (tie_word_embeddings)
    Quantizer->>Model: copy_lm_head_weight_from_embed_tokens()
    Quantizer->>Model: freeze_qwen3_embed_tokens_weight()
    Trainer->>Calibrator: Run calibration on Model
    Calibrator->>Model: Forward pass (collect activation ranges)
    Trainer->>Quantizer: Enable fake quantization
    Trainer->>Quantizer: recompute_scale_zp()
    Quantizer->>FQ: Refresh scale/zero_point per layer
    FQ-->>Quantizer: Updated parameters
    Trainer->>Quantizer: validate_concat_observer()
    Quantizer->>Observer: Audit input observer consistency
    Observer-->>Quantizer: Per-observer metrics & mismatches
    Trainer->>Converter: Convert to QNN AOT
    Converter->>Model: Process per-layer quantized tensors
    Converter-->>Trainer: Compiled artifact
```
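Read top to bottom, the diagram amounts to the driver loop sketched below. This is illustrative only: the `quantizer`/`converter` call signatures are assumptions, while `recompute_scale_zp` and `validate_concat_observer` are the utilities this PR adds to runner.py.

```python
import torch

def run_qat_to_aot(model, quantizer, converter, calib_loader):
    # Sync and freeze tied embeddings before any statistics are gathered
    # (method names follow the diagram; exact signatures are assumed).
    quantizer.copy_lm_head_weight_from_embed_tokens(model)
    quantizer.freeze_qwen3_embed_tokens_weight(model)

    # Calibration: forward passes let observers collect activation ranges.
    model.eval()
    with torch.no_grad():
        for batch in calib_loader:
            model(batch)

    # Refresh FakeQuantize buffers from the (possibly concat-updated)
    # observers, then audit that concat inputs agree on scale/zero_point.
    model.apply(recompute_scale_zp)   # from this PR's runner.py
    validate_concat_observer(model)   # from this PR's runner.py

    # Lower the calibrated, fake-quantized graph to a QNN AOT artifact.
    return converter.convert(model)
```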
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp (2)
288-289: `layer_idx_` is uninitialized in the constructor.
`layer_idx_` is declared but never initialized within `Qwen3Attention`'s constructor. It's set externally at line 347 in `Qwen3Text`. If the member is accessed before external initialization (e.g., during debugging, or if the initialization order changes), this could lead to undefined behavior. Consider initializing it in the constructor or using a default value.
🔧 Suggested fix
```diff
 public:
   Qwen3Attention() = default;
-  Qwen3Attention(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) {
+  Qwen3Attention(const std::string& name, const Qwen3Config& cfg, int layer_idx = -1) : nn::Module(name), layer_idx_(layer_idx) {
     hidden_size_ = cfg.hidden_size;
```
486-486: Storing a reference to `cfg` may lead to a dangling reference.
`const Qwen3Config& cfg` is stored as a member, but if the original `Qwen3Config` object passed to the constructor is destroyed or goes out of scope, this reference becomes dangling, leading to undefined behavior when accessed in `trace()`. Consider storing by value, or using `std::shared_ptr` if the config is expensive to copy.
🔧 Suggested fix
```diff
- const Qwen3Config& cfg;
+ Qwen3Config cfg;
```
🤖 Fix all issues with AI agents
In `@pymllm/backends/qualcomm/transformers/qwen3/runner.py`:
- Around line 19-82: The debug print in recompute_scale_zp incorrectly
references module.scale; update the log to reference the FakeQuantize buffer by
using module.fake_quant.scale instead (i.e., change the f-string in the loop
over module.fake_quant.named_parameters() to print {module.fake_quant.scale});
ensure the print remains inside the loop that checks for value is
module.fake_quant.scale so it logs the correct tensor from FakeQuantize.
- Around line 100-138: The code assumes per-tensor quantization by calling
.item() on scale and zp from observer.calculate_qparams() (used in the
ConcatObserver logging and mismatch messages); make this defensive: when
collecting and printing scales_zps and when formatting mismatch messages in the
loop, detect if scale.numel() == 1 and use .item(), otherwise convert to a list
(scale.tolist(), zp.tolist()) and format accordingly (e.g., show full list or
summarized stats). Also ensure comparisons still work for multi-element qparams
by keeping torch.allclose(ref_scale, scale, ...) and torch.equal(ref_zp, zp)
(they already support multi-element tensors), and add a short comment near
input_observers / ConcatObserver indicating the code supports both per-tensor
and per-channel qparams.
🧹 Nitpick comments (3)
pymllm/backends/qualcomm/transformers/qwen3/runner.py (1)
52-57: Consider using the logging module instead of print for error handling.
The broad `except Exception` is acceptable here given the variety of observer implementations, but using `logging.warning` or `logging.debug` instead of `print(e)` would provide better control over verbosity and avoid cluttering stdout in production.
Proposed improvement

```diff
+import logging
+
+logger = logging.getLogger(__name__)
+
 # In recompute_scale_zp function:
     try:
         scale, zero_point = observer.calculate_qparams()
     except Exception as e:
         # Some special Observers (e.g., FixedQParams) may not support recomputation or behave differently, safely skip
-        print(e)
+        logger.debug("Skipping observer recomputation: %s", e)
         return
```

examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp (2)
20-30: Redundant initialization of `scale_name` and `zp_name`.
Lines 21-22 initialize `scale_name` and `zp_name`, but the if-else block at lines 24-30 always overwrites these values. Lines 28-29 in the else branch are identical to lines 21-22.
♻️ Suggested simplification

```diff
 Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) {
-  std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale";
-  std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point";
-
-  if (m->getModuleName().empty()) {
-    scale_name = qdq_name_in_pytorch + ".fake_quant.scale";
-    zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point";
-  } else {
-    scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale";
-    zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point";
-  }
+  const auto& modName = m->getModuleName();
+  std::string prefix = modName.empty() ? "" : modName + ".";
+  std::string scale_name = prefix + qdq_name_in_pytorch + ".fake_quant.scale";
+  std::string zp_name = prefix + qdq_name_in_pytorch + ".fake_quant.zero_point";
```
301-307: `layer_idx` is not propagated to `self_attn_`.
The constructor receives `layer_idx` and stores it in `layer_idx_`, but doesn't pass it to `self_attn_` during registration. Instead, `self_attn_.layer_idx_` is set externally in `Qwen3Text` at line 347. This creates a fragmented initialization pattern where the decoder knows its layer index but relies on an external caller to set the same index on its child attention module. Consider propagating the index during construction for encapsulation, or document this dependency clearly.
♻️ Suggested fix to propagate layer_idx
If the `Qwen3Attention` constructor is updated to accept `layer_idx`:

```diff
 Qwen3Decoder(const std::string& name, const Qwen3Config& cfg, int layer_idx) : nn::Module(name) {
   layer_idx_ = layer_idx;
-  self_attn_ = reg<Qwen3Attention>("self_attn", cfg);
+  self_attn_ = reg<Qwen3Attention>("self_attn", cfg, layer_idx);
   mlp_ = reg<Qwen3MLP>("mlp", cfg);
```

Then remove line 347 in `Qwen3Text`:

```diff
 decode_blocks_ = reg<nn::ModuleListWithIdx<Qwen3Decoder>>("layers", cfg.num_hidden_layers, cfg);
-for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; }
 norm_ = reg<nn::RMSNorm>("norm", cfg.rms_norm_eps);
```
```python
def recompute_scale_zp(module):
    """
    Callback function: Used to forcefully refresh scale and zero_point of all FakeQuantize modules after calibration.

    Problem solved:
        When using ConcatObserver, min/max may be updated during forward pass,
        but at the end of forward, the scale/zp stored in FakeQuantize's internal buffer are still computed from old min/max.
        This function forces a calculate_qparams call to sync the latest parameters to the buffer.

    Usage:
        model.apply(recompute_scale_zp)
    """
    # We mainly focus on FakeQuantize modules since they store the scale/zero_point buffers
    # Note: model.apply recursively traverses all submodules, so self.fake_quant inside ActivationQDQ will also be visited
    if isinstance(module, ActivationQDQ):
        observer = module.fake_quant.activation_post_process

        # 2. Check if observer is valid and contains statistics
        # We only care about MinMaxObserver or MovingAverageMinMaxObserver that have min_val/max_val
        if hasattr(observer, "min_val") and hasattr(observer, "max_val"):
            # 3. Check if data is initialized
            # If min_val is still the initial inf, this layer hasn't processed data, skip to avoid errors
            if observer.min_val.numel() == 0 or observer.max_val.numel() == 0:
                return
            if (
                torch.isinf(observer.min_val).any()
                or torch.isinf(observer.max_val).any()
            ):
                return

            # 4. Recompute Scale and Zero Point
            # calculate_qparams reads the current min_val/max_val from observer (may have been modified by ConcatObserver)
            try:
                scale, zero_point = observer.calculate_qparams()
            except Exception as e:
                # Some special Observers (e.g., FixedQParams) may not support recomputation or behave differently, safely skip
                print(e)
                return

            # 5. Force overwrite the computed results to FakeQuantize's Buffer
            # Use copy_ to keep reference unchanged, ensuring the new values are used during export
            if (
                hasattr(module.fake_quant, "scale")
                and module.fake_quant.scale is not None
            ):
                # Ensure dimension match (handle per-channel vs per-tensor)
                if module.fake_quant.scale.shape != scale.shape:
                    module.fake_quant.scale.resize_(scale.shape)
                module.fake_quant.scale.copy_(scale)
                # Try to get the registered name of module scale from _parameters or _buffers
                for key, value in module.fake_quant.named_parameters():
                    if value is module.fake_quant.scale:
                        print(f"{module._get_name()}.{key}: {module.scale}")
                        break

            if (
                hasattr(module.fake_quant, "zero_point")
                and module.fake_quant.zero_point is not None
            ):
                if module.fake_quant.zero_point.shape != zero_point.shape:
                    module.fake_quant.zero_point.resize_(zero_point.shape)
                module.fake_quant.zero_point.copy_(zero_point)
```
Bug: Incorrect attribute reference in debug print statement.
Line 72 prints `module.scale` but, based on the surrounding code, it should be `module.fake_quant.scale`.
Proposed fix

```diff
 for key, value in module.fake_quant.named_parameters():
     if value is module.fake_quant.scale:
-        print(f"{module._get_name()}.{key}: {module.scale}")
+        print(f"{module._get_name()}.{key}: {module.fake_quant.scale}")
         break
```

🧰 Tools
🪛 Ruff (0.14.13)
54-54: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
In `@pymllm/backends/qualcomm/transformers/qwen3/runner.py` around lines 19 - 82,
The debug print in recompute_scale_zp incorrectly references module.scale;
update the log to reference the FakeQuantize buffer by using
module.fake_quant.scale instead (i.e., change the f-string in the loop over
module.fake_quant.named_parameters() to print {module.fake_quant.scale}); ensure
the print remains inside the loop that checks for value is
module.fake_quant.scale so it logs the correct tensor from FakeQuantize.
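The `resize_`/`copy_` pattern in step 5 of the quoted recompute_scale_zp is what keeps exporters working: it mutates the registered buffer in place instead of rebinding the attribute. A standalone illustration (toy module, not repo code):

```python
import torch
import torch.nn as nn

fq = nn.Module()
fq.register_buffer("scale", torch.tensor([1.0]))
captured = fq.scale                  # reference held elsewhere, e.g. by an exporter

fq.scale.copy_(torch.tensor([0.5]))  # in-place: same tensor object, new value
assert captured.item() == 0.5        # the captured reference sees the update

# By contrast, `fq.scale = torch.tensor([0.5])` would rebind the attribute
# and leave `captured` pointing at the stale tensor.
```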
```python
# Collect scale and zero_point from all observers
scales_zps = []
for i, observer in enumerate(input_observers):
    try:
        scale, zp = observer.calculate_qparams()
        scales_zps.append(f"[{i}] s={scale.item():.8f} zp={zp.item()}")
    except Exception:
        scales_zps.append(f"[{i}] failed")

# Print one line: scale and zp of all inputs for each concat observer
print(f"ConcatObserver [{name}]: {' | '.join(scales_zps)}")

# Original validation logic
if len(input_observers) <= 1:
    return

# Get scale and zero_point from the first observer as reference
first_observer = input_observers[0]
try:
    ref_scale, ref_zp = first_observer.calculate_qparams()
except Exception:
    return

# Check if all other observers have the same scale and zero_point
for i, observer in enumerate(input_observers[1:], start=1):
    try:
        scale, zp = observer.calculate_qparams()
    except Exception:
        results.append(f"Failed to calculate qparams for observer[{i}]")
        continue

    scale_match = torch.allclose(ref_scale, scale, rtol=1e-5, atol=1e-8)
    zp_match = torch.equal(ref_zp, zp)

    if not scale_match or not zp_match:
        results.append(
            f"observer[{i}] mismatch: ref_scale={ref_scale.item():.8f}, "
            f"scale={scale.item():.8f}, ref_zp={ref_zp.item()}, zp={zp.item()}"
        )
```
Potential issue: `.item()` calls assume per-tensor quantization.
Lines 105 and 136-137 call `.item()` on scale and zp tensors. This will fail (PyTorch raises a RuntimeError when converting a multi-element tensor to a scalar) if the observer uses per-channel quantization, where scale/zp are multi-element tensors.
Given that the ConcatObserver configuration in modeling_qwen3.py uses `per_tensor_affine`, this may be safe, but it's worth adding defensive handling or a comment clarifying the assumption.
Proposed defensive fix

```diff
 for i, observer in enumerate(input_observers):
     try:
         scale, zp = observer.calculate_qparams()
-        scales_zps.append(f"[{i}] s={scale.item():.8f} zp={zp.item()}")
+        if scale.numel() == 1:
+            scales_zps.append(f"[{i}] s={scale.item():.8f} zp={zp.item()}")
+        else:
+            scales_zps.append(f"[{i}] s={scale.tolist()} zp={zp.tolist()}")
     except Exception:
         scales_zps.append(f"[{i}] failed")
```

🧰 Tools
🪛 Ruff (0.14.13)
106-106: Do not catch blind exception: Exception
(BLE001)
120-120: Do not catch blind exception: Exception
(BLE001)
127-127: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
In `@pymllm/backends/qualcomm/transformers/qwen3/runner.py` around lines 100 -
138, The code assumes per-tensor quantization by calling .item() on scale and zp
from observer.calculate_qparams() (used in the ConcatObserver logging and
mismatch messages); make this defensive: when collecting and printing scales_zps
and when formatting mismatch messages in the loop, detect if scale.numel() == 1
and use .item(), otherwise convert to a list (scale.tolist(), zp.tolist()) and
format accordingly (e.g., show full list or summarized stats). Also ensure
comparisons still work for multi-element qparams by keeping
torch.allclose(ref_scale, scale, ...) and torch.equal(ref_zp, zp) (they already
support multi-element tensors), and add a short comment near input_observers /
ConcatObserver indicating the code supports both per-tensor and per-channel
qparams.
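For context on why this audit matters: a backend that concatenates raw quantized tensors can only attach one scale/zero_point to the output, so branches calibrated with different scales get silently rescaled. A toy demonstration (illustrative only, not repo code):

```python
import torch

def quant(x, scale, zp):
    return torch.round(x / scale) + zp

def dequant(q, scale, zp):
    return (q - zp) * scale

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0])
qa = quant(a, scale=0.1, zp=0)  # branch A observer chose scale=0.1
qb = quant(b, scale=0.2, zp=0)  # branch B observer chose scale=0.2

merged = torch.cat([qa, qb])            # concat of raw quantized values
out = dequant(merged, scale=0.1, zp=0)  # a single output scale must be picked
print(out)  # tensor([1.0, 2.0, 0.5, 1.0]) -- branch B values are halved
```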