diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp
index a724782a..ba1cdb22 100644
--- a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp
+++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp
@@ -132,7 +132,7 @@ Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorc
 }

 using vi32 = std::vector;
-#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32
+#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G16

 // Using Conv2D to replace Linear.
 // Conv2D Filter Weight is [1, 1, In, Out]
diff --git a/examples/qwen3_qnn_aot/qnn_aot_cfg_1.7B.json b/examples/qwen3_qnn_aot/qnn_aot_cfg_1.7B.json
index b20b4c22..83fc6d42 100644
--- a/examples/qwen3_qnn_aot/qnn_aot_cfg_1.7B.json
+++ b/examples/qwen3_qnn_aot/qnn_aot_cfg_1.7B.json
@@ -23,7 +23,7 @@
       "method": "LPBQ",
       "sym": true,
       "precision": "w4a16",
-      "block_size": 32
+      "block_size": 16
     }
   },
   "linear": {
@@ -31,7 +31,7 @@
       "method": "LPBQ",
       "sym": true,
       "precision": "w4a16",
-      "block_size": 32
+      "block_size": 16
     }
   },
   "kv_cache": {
diff --git a/mllm/backends/cpu/ops/Conv2DOp.cpp b/mllm/backends/cpu/ops/Conv2DOp.cpp
index 1e071347..9c05429e 100644
--- a/mllm/backends/cpu/ops/Conv2DOp.cpp
+++ b/mllm/backends/cpu/ops/Conv2DOp.cpp
@@ -61,6 +61,7 @@ void CPUConv2DOp::load(const ParameterFile::ptr_t& ploader) {
       weight_ = packed_weight;
       break;
     }
+    case aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G16:
     case aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32:
     case aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G64: {
       break;
diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp
index b8b8e7b5..e7478105 100644
--- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp
+++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp
@@ -172,6 +172,7 @@ void QnnAOTNodeTensor::setupComplexTensorQuantization(const ir::tensor::TensorVa
       break;
     }
     case ir::linalg::QuantizationSpecType::kLPBQ: {
+      MLLM_INFO("Solving LPBQ quantization for tensor: {}", v->tensor_.name());
       // This LPBQ Type is for Conv2D Only !!! Linear has diff layout cmp with conv2d
       auto cfg = std::static_pointer_cast(quant_spec);
@@ -182,7 +183,7 @@ void QnnAOTNodeTensor::setupComplexTensorQuantization(const ir::tensor::TensorVa
       MLLM_RT_ASSERT_EQ(num_scale_offsets, cfg->scale_level_1_fp.size(-1));
       MLLM_RT_ASSERT_EQ(cfg->scale_level_0_int.dtype(), kUInt8);
       for (int i = 0; i < num_scale_offsets; ++i) {
-        scale_offsets[i].scale = cfg->scale_level_1_fp.at({0, 0, 0, i});
+        scale_offsets[i].scale = cfg->scale_level_1_fp.at({i});
         scale_offsets[i].offset = 0;
       }
diff --git a/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp b/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp
index 63b58e0b..69bd9357 100644
--- a/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp
+++ b/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp
@@ -83,7 +83,7 @@ uint8_t LLM2QnnLoweringPass::run(const ir::node_ptr_t& op) {
   for (auto& region_op : model_op->getTopRegion()->ops()) {
     if (auto sub_graph_op = std::dynamic_pointer_cast(region_op)) {
       auto symbol_attr = sub_graph_op->getSymbolAttr();
-      if (symbol_attr) { subgraphs[symbol_attr->str()] = sub_graph_op; }
+      if (symbol_attr && symbol_attr->str() != "init") { subgraphs[symbol_attr->str()] = sub_graph_op; }
     }
   }
diff --git a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp
index 60e95c9c..444abf57 100644
--- a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp
+++ b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp
@@ -317,8 +317,9 @@ bool LLMQuantRecipeConv2DPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr
   ir::linalg::QuantizationSpecLPBQ::ptr_t weight_quant_spec = nullptr;

   if (precision == "w4a16") {
+    // HWIO
     weight_quant_spec =
-        ir::linalg::QuantizationSpecLPBQ::create(-8, 7, block_size, 0, 4, kUInt4, kFloat32, Tensor::nil(), Tensor::nil());
+        ir::linalg::QuantizationSpecLPBQ::create(-7, 7, block_size, 3, 4, kInt4, kFloat32, Tensor::nil(), Tensor::nil());

     // output sym int16
     auto out_quant_spec = ir::linalg::QuantizationSpecAsymPerTensor::create(0, 65536 - 1, kUInt16, kFloat32, kInt32,
diff --git a/mllm/backends/qnn/aot/passes/SplitLLMGraphPass.cpp b/mllm/backends/qnn/aot/passes/SplitLLMGraphPass.cpp
index 5c2ff94f..4e0cf45e 100644
--- a/mllm/backends/qnn/aot/passes/SplitLLMGraphPass.cpp
+++ b/mllm/backends/qnn/aot/passes/SplitLLMGraphPass.cpp
@@ -197,7 +197,7 @@ uint8_t SplitLLMGraphPass::run(const ir::node_ptr_t& op) {
         if (name == "model." + std::to_string(i) + ".s" + std::to_string(__global_seq_len)) { matched = true; }
         if (name == "model") { matched = true; }
       }
-      if (!matched) { wvw.removeOp(sub_g_op); }
+      if (!matched && name != "init") { wvw.removeOp(sub_g_op); }
       return ir::IRWriter::WalkResult::WALK_CONTINUE;
     });
   }
diff --git a/mllm/backends/qnn/aot/visitor/Conv2D.cpp b/mllm/backends/qnn/aot/visitor/Conv2D.cpp
index 172992e8..083a12b3 100644
--- a/mllm/backends/qnn/aot/visitor/Conv2D.cpp
+++ b/mllm/backends/qnn/aot/visitor/Conv2D.cpp
@@ -94,6 +94,16 @@ bool QnnAOTConv2DPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op)
     qnn_op_node->emplaceParamTensor(pad_amount_param);
   }

+  // Add params: dilation
+  {
+    auto dilation_param = QNNParamTensorWrapper::create("dilation", base_op->getName() + ".dilation", QNN_DATATYPE_UINT_32,
+                                                        std::vector{2});
+    uint32_t* data = static_cast(dilation_param->alloc());
+    data[0] = 1;
+    data[1] = 1;
+    qnn_op_node->emplaceParamTensor(dilation_param);
+  }
+
   // Register this op node into one graph.
   env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node);
diff --git a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp
index f0feee0c..bee78b1b 100644
--- a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp
+++ b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp
@@ -65,8 +65,7 @@ void PromptProcessor::init_io() {
   output_tensors_.reserve(1 + 2 * config_.num_layers);

   // 1. Logits
-  // DBG:
-  auto logits = Tensor::empty({1, 1, config_.ar_len, 2048}, kUInt16, kQNN).alloc();
+  auto logits = Tensor::empty({1, 1, config_.ar_len, config_.vocab_size}, kUInt16, kQNN).alloc();
   logits.setName("logits");
   output_tensors_.push_back(logits);
@@ -132,7 +131,7 @@ int64_t PromptProcessor::prefill(const std::vector& prompt_tokens, i
   prepare_io(prompt_tokens, processed_tokens, current_pos);

-  auto module_input = input_tensors_;
+  std::vector module_input = input_tensors_;
   output_tensors_ = (*module_)(module_input);

   int32_t n_update = chunk_size;
diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp
index 846ce7c8..3bbd077d 100644
--- a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp
+++ b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp
@@ -1,9 +1,10 @@
 // Copyright (c) MLLM Team.
 // Licensed under the MIT License.

-#include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp"
 #include
 #include
+
+#include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp"
 #include "mllm/core/DataTypes.hpp"
 #include "mllm/core/DeviceTypes.hpp"
 #include "mllm/preprocessor/tokenizers/Unicode.hpp"
diff --git a/mllm/core/aops/Conv2DOp.cpp b/mllm/core/aops/Conv2DOp.cpp
index d6934e10..5a1853aa 100644
--- a/mllm/core/aops/Conv2DOp.cpp
+++ b/mllm/core/aops/Conv2DOp.cpp
@@ -78,7 +78,8 @@ void Conv2DOp::reshape(const std::vector& inputs, std::vector& o
   // CHECK if in Qualcomm DSP shape. Inputs is [N, H, W, C], Filter Weight is [N, H, In, Out]
   if (options_.impl_type == Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32
-      || options_.impl_type == Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G64) {
+      || options_.impl_type == Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G64
+      || options_.impl_type == Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G16) {
     in_channels = ishape[3];
     in_height = ishape[1];
     in_width = ishape[2];
@@ -112,7 +113,8 @@ void Conv2DOp::reshape(const std::vector& inputs, std::vector& o
   auto new_shape = std::vector{batch, out_channels, h_out, w_out};

   if (options_.impl_type == Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32
-      || options_.impl_type == Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G64) {
+      || options_.impl_type == Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G64
+      || options_.impl_type == Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G16) {
     new_shape = std::vector{batch, h_out, w_out, out_channels};
   }
diff --git a/mllm/core/aops/Conv2DOp.hpp b/mllm/core/aops/Conv2DOp.hpp
index 02466b3c..33c2e43e 100644
--- a/mllm/core/aops/Conv2DOp.hpp
+++ b/mllm/core/aops/Conv2DOp.hpp
@@ -12,6 +12,7 @@ enum class Conv2DOpImplType {
   kDefault = 0,

   // LPBQ
+  kQNN_LPBQ_w4a16o16_G16,
   kQNN_LPBQ_w4a16o16_G32,
   kQNN_LPBQ_w4a16o16_G64,
 };
@@ -28,7 +29,11 @@ struct Conv2DOpOptions : public BaseOpOptions {
 };

 inline Conv2DOpImplType str2Conv2DOpImplType(const std::string& str) {
-  static const std::unordered_map map = {{"Default", Conv2DOpImplType::kDefault}};
+  static const std::unordered_map map = {
+      {"Default", Conv2DOpImplType::kDefault},
+      {"QNN_LPBQ_w4a16o16_G16", Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G16},
+      {"QNN_LPBQ_w4a16o16_G32", Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32},
+      {"QNN_LPBQ_w4a16o16_G64", Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G64}};

   auto it = map.find(str);
   if (it != map.end()) { return it->second; }
@@ -38,7 +43,11 @@ inline Conv2DOpImplType str2Conv2DOpImplType(const std::string& str) {
 }

 inline std::string Conv2DOpImplType2Str(Conv2DOpImplType type) {
-  static const std::unordered_map map = {{Conv2DOpImplType::kDefault, "Default"}};
+  static const std::unordered_map map = {
+      {Conv2DOpImplType::kDefault, "Default"},
+      {Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G16, "QNN_LPBQ_w4a16o16_G16"},
+      {Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32, "QNN_LPBQ_w4a16o16_G32"},
+      {Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G64, "QNN_LPBQ_w4a16o16_G64"}};

   auto it = map.find(type);
   if (it != map.end()) return it->second;
diff --git a/pymllm/backends/qualcomm/transformers/core/embedding.py b/pymllm/backends/qualcomm/transformers/core/embedding.py
index 6af0b305..b581f23f 100644
--- a/pymllm/backends/qualcomm/transformers/core/embedding.py
+++ b/pymllm/backends/qualcomm/transformers/core/embedding.py
@@ -124,9 +124,12 @@ def freeze_weight(self):
             f"Class: {class_name}, Instance: {instance_class_name}, Weight Quantized: scale={self.weight_fake_quant.scale}, zp={self.weight_fake_quant.zero_point}"
         )

-    def disable_quant(self):
+    def disable_fakequant(self):
         """Completely turn off quantization noise and return to floating point mode"""
-        self.weight_fake_quant.disable_fakequant()
+        self.weight_fake_quant.disable_fake_quant()
+
+    def enable_fakequant(self):
+        self.weight_fake_quant.enable_fake_quant()

     def extra_repr(self):
         s = f"{self.num_embeddings}, {self.embedding_dim}"
diff --git a/pymllm/backends/qualcomm/transformers/core/observer.py b/pymllm/backends/qualcomm/transformers/core/observer.py
index 67a946b1..f5fd8bc4 100644
--- a/pymllm/backends/qualcomm/transformers/core/observer.py
+++ b/pymllm/backends/qualcomm/transformers/core/observer.py
@@ -1,5 +1,147 @@
 import torch
+from typing import Tuple
 from torchao.quantization.pt2e import UniformQuantizationObserverBase
+from torchao.quantization.pt2e import FakeQuantize, MappingType, PerBlock
+from torchao.quantization.pt2e._affine_quantization import (
+    _get_reduction_params,
+    AffineQuantizedMinMaxObserver,
+    choose_qparams_affine_with_min_max,
+)
+from torchao.quantization.quant_primitives import _fake_quantize_affine
+
+
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+class PerBlockParamObserver(AffineQuantizedMinMaxObserver):
+    def __init__(
+        self,
+        dtype: torch.dtype,
+        block_size: torch.Size,
+        quant_min=None,
+        quant_max=None,
+        eps=torch.finfo(torch.float32).eps,  # noqa: B008
+        **kwargs,
+    ):
+        super().__init__(
+            mapping_type=MappingType.SYMMETRIC,
+            target_dtype=dtype,
+            granularity=PerBlock,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps,
+            **kwargs,
+        )
+        self.dtype = dtype
+        self.block_size = block_size
+        # TODO: expand this when QNN starts to support more configurations
+        self.bitwidth_of_scale = 4
+        self.num_steps = 2**self.bitwidth_of_scale
+        self.calibrated = False
+
+    def forward(self, input: torch.Tensor):
+        if input.numel() == 0 or self.calibrated:
+            return input
+
+        input_detached = input.detach()
+        self.original_dtype = input_detached.dtype
+        shape_for_reduction, reduction_dims = _get_reduction_params(
+            self.block_size, input_detached.size()
+        )
+        input_detached = input_detached.view(shape_for_reduction)
+        min_val = torch.amin(input_detached, dim=reduction_dims)
+        max_val = torch.amax(input_detached, dim=reduction_dims)
+        if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
+            self.min_val = min_val
+            self.max_val = max_val
+        else:
+            assert self.min_val.shape == min_val.shape, (
+                f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}"
+            )
+            assert self.max_val.shape == max_val.shape, (
+                f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}"
+            )
+            min_val = torch.min(self.min_val, min_val)
+            max_val = torch.max(self.max_val, max_val)
+            self.min_val.copy_(min_val)
+            self.max_val.copy_(max_val)
+
+        self.calibrated = True
+        return input
+
+    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert hasattr(self, "min_val") and hasattr(self, "max_val"), (
+            "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
+        )
+        return choose_qparams_affine_with_min_max(
+            self.min_val,
+            self.max_val,
+            self.mapping_type,
+            [],
+            self.target_dtype,
+            self.quant_min,
+            self.quant_max,
+            self.eps,
+            self.scale_dtype,
+            self.zero_point_dtype,
+            self.preserve_zero,
+            self.zero_point_domain,
+        )
+
+
+class PerBlockParamFakeQuantize(FakeQuantize):
+    def __init__(
+        self,
+        dtype: torch.dtype = torch.int8,
+        block_size: torch.Size = None,
+        quant_min: int = None,
+        quant_max: int = None,
+        eps: float = torch.finfo(torch.float32).eps,  # noqa: B008
+        **kwargs,
+    ):
+        super().__init__()
+        assert block_size is not None, (
+            "block_size must be provided for per-block quantization"
+        )
+
+        self.activation_post_process = PerBlockParamObserver(
+            dtype=dtype,
+            block_size=block_size,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps,
+            **kwargs,
+        )
+        self.dtype = dtype
+        self.block_size = block_size
+        self.quant_min = quant_min if quant_min is not None else torch.iinfo(dtype).min
+        self.quant_max = quant_max if quant_max is not None else torch.iinfo(dtype).max
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.numel() == 0:
+            return x
+
+        self.activation_post_process(x)
+        scale, zero_point = self.activation_post_process.calculate_qparams()
+
+        return _fake_quantize_affine(
+            x,
+            self.block_size,
+            scale,
+            zero_point,
+            quant_dtype=self.dtype,
+            quant_min=self.quant_min,
+            quant_max=self.quant_max,
+        )
+
+    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.activation_post_process.calculate_qparams()
+
+    def convert(self, model, observer_node):
+        self.activation_post_process.convert(model, observer_node)


 class ConcatObserver(UniformQuantizationObserverBase):
diff --git a/pymllm/backends/qualcomm/transformers/core/qlinear.py b/pymllm/backends/qualcomm/transformers/core/qlinear.py
index 255f52ff..d3bc8150 100644
--- a/pymllm/backends/qualcomm/transformers/core/qlinear.py
+++ b/pymllm/backends/qualcomm/transformers/core/qlinear.py
@@ -2,6 +2,13 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.ao.quantization import FakeQuantize, PerChannelMinMaxObserver
+from pymllm.backends.qualcomm.transformers.core.observer import (
+    PerBlockParamFakeQuantize,
+)
+from torchao.quantization.quant_primitives import (
+    _quantize_affine,
+    _get_reduction_params,
+)


 class QLinear(nn.Module):
@@ -25,15 +32,18 @@ def freeze_weight(self):
         if self.weight_quant is not None:
             # Compatible with official FakeQuantize module
             if (
-                isinstance(self.weight_quant, FakeQuantize)
+                isinstance(self.weight_quant, PerBlockParamFakeQuantize)
                 and self.weight_quant is not None
             ):
-                _ = self.weight_quant(self.weight)
+                self.weight_quant.enable_observer()
+                self.weight_quant.activation_post_process(self.weight)
+                s, zp = self.weight_quant.activation_post_process.calculate_qparams()
                 self.weight_quant.disable_observer()
-                s = self.weight_quant.scale
+                self.weight_quant.scale = s
+                self.weight_quant.zero_point = zp
                 print(
                     f"[{self.__class__.__name__}] Scale Shape: {list(s.shape)}, "
-                    f"scale[:3]: {s.flatten()[:3].tolist()}"
+                    f"scale[:3]: {s.flatten()[:3].tolist()}, zp: {zp.flatten()[:3].tolist()}"
                 )
             # Compatible with custom LPBQ logic
             elif hasattr(self.weight_quant, "freeze"):
@@ -157,78 +167,24 @@ def convert_to_conv2d_deploy_hwio(self):
         )


-# --- 2. LPBQ (Double Quantization) Scheme ---
-class DoubleQuantizer(nn.Module):
-    """
-    Handles LPBQ double normalization logic to work like FakeQuantize
-    """
-
-    def __init__(self, block_size=64):
-        super().__init__()
-        self.block_size = block_size
-        self.register_buffer("is_frozen", torch.tensor(False))
-        self.register_buffer("scale_2_fp32", None)
-        self.register_buffer("scale_1_uint4", None)
-        self.register_buffer("weight_q", None)
-        self.w_recon_cached = None  # Cache dequantized weights for acceleration
-
-    def freeze(self, w):
-        # Run complete double quantization and store in buffer
-        self.w_recon_cached = self.quantize_dequantize(w, save_buffers=True)
-        self.is_frozen = torch.tensor(True)
-
-    def quantize_dequantize(self, w, save_buffers=False):
-        out_channels, in_channels = w.shape
-        # 1. Padding handling
-        pad_len = (self.block_size - in_channels % self.block_size) % self.block_size
-        if pad_len > 0:
-            w = F.pad(w, (0, pad_len), "constant", 0)
-
-        w_reshaped = w.view(out_channels, -1, self.block_size)
-
-        # Level 1: FP32 Scale
-        s1 = w_reshaped.abs().amax(dim=-1, keepdim=True) / 7.0
-        s1 = s1.clamp(min=1e-8)
-
-        # Level 2: Quantize S1 to Uint4
-        s2 = s1.amax(dim=1, keepdim=True) / 15.0
-        s2 = s2.clamp(min=1e-8)
-        s1_q = (s1 / s2).round().clamp(0, 15)
-        s1_recon = s1_q * s2
-
-        # Level 3: Quantize Weight to Int4
-        w_q = (w_reshaped / s1_recon).round().clamp(-8, 7)
-        w_recon = w_q * s1_recon
-
-        if save_buffers:
-            self.scale_2_fp32 = s2.detach()
-            self.scale_1_uint4 = s1_q.detach().to(torch.uint8)
-            self.weight_q = w_q.detach().to(torch.int8)
-
-        # Restore shape
-        w_out = w_recon.view(out_channels, -1)
-        if pad_len > 0:
-            w_out = w_out[:, :-pad_len]
-        return w_out
-
-    def forward(self, w):
-        if self.is_frozen:
-            # If frozen, directly return cached reconstructed weights (or real-time dequantization from Buffer)
-            if self.w_recon_cached is None:
-                # Logic to reconstruct from weight_q + scale_1 + scale_2 can be written here
-                pass
-            return (
-                self.w_recon_cached
-                if self.w_recon_cached is not None
-                else self.quantize_dequantize(w)
-            )
-        return self.quantize_dequantize(w)
-
-
 class QLinearLPBQ(QLinear):
     def __init__(self, in_features, out_features, bias=True, block_size=64):
         super().__init__(in_features, out_features, bias)
-        self.weight_quant = DoubleQuantizer(block_size)
+        self.block_size = [1, block_size]
+        self.weight_quant = PerBlockParamFakeQuantize(
+            dtype=torch.int8,
+            quant_min=-7,
+            quant_max=7,
+            block_size=self.block_size,
+            eps=0.0001 / 65535,
+            ch_axis=0,
+        )
+
+    def enable_fakequant(self):
+        self.weight_quant.enable_fake_quant()
+
+    def disable_fakequant(self):
+        self.weight_quant.disable_fake_quant()

     def forward(self, x):
         # Must use quantized weights w_q for computation
@@ -236,80 +192,64 @@ def forward(self, x):
         return F.linear(x, w_q, self.bias)

     @torch.no_grad()
-    def convert_to_deploy(self):
-        if self.deploy_mode:
-            return
-
-        del self.weight
-        self.register_buffer(
-            "weight",
-            self.weight_quant.weight_q.reshape(self.weight_quant.weight_q.shape[0], -1),
-        )
-        self.register_buffer("scale1", self.weight_quant.scale_1_uint4)
-        self.register_buffer("scale2", self.weight_quant.scale_2_fp32)
-        del self.weight_quant
-
-        self.deploy_mode = True
+    def convert_to_conv2d_deploy_hwio(self):
+        linear_scale = self.weight_quant.scale
+        linear_zero_point = self.weight_quant.zero_point
         print(
-            f"[{self.__class__.__name__}] Converted to deploy. Original float weight removed."
+            "Original Linear Scale[:3]: , zp[:3]: ",
+            linear_scale.flatten()[:3].tolist(),
+            linear_zero_point.flatten()[:3].tolist(),
         )

-    @torch.no_grad()
-    def convert_to_conv2d_deploy_hwio(self):
-        """
-        Convert to deploy format with HWIO layout [1, 1, In, Out].
-        This format is commonly used by convolution-based inference engines.
-        """
-        if self.deploy_mode:
-            return
-        if not self.weight_quant.is_frozen:
-            self.freeze_weight()
-
-        # Step 1: Extract quantized weights in block format
-        # Shape: [Out, Blocks, BlockSize]
-        w_q_blocks = self.weight_quant.weight_q
-
-        # Step 2: Flatten and remove padding
-        w_q_flat = w_q_blocks.view(self.out_features, -1)  # Shape: [Out, In_Padded]
-        if w_q_flat.shape[1] > self.in_features:
-            w_q_flat = w_q_flat[:, : self.in_features]
-
-        # Step 3: Critical step - Transpose weights
-        # [Out, In] -> [In, Out]
-        w_transposed = w_q_flat.t().contiguous()
-
-        # Step 4: Reshape to HWIO [1, 1, In, Out]
-        w_hwio = w_transposed.view(1, 1, self.in_features, self.out_features)
-
-        # Step 5: Process LPBQ Scales
-        # Scale2 (Per-Channel): Original [Out, 1, 1]
-        # Target: [1, 1, 1, Out]
-        s2 = self.weight_quant.scale_2_fp32
-        s2_hwio = s2.flatten().view(1, 1, 1, self.out_features)
-
-        # Scale1 (Per-Block): Original [Out, n_blocks, 1]
-        # n_blocks corresponds to Input Channel blocking
-        # When weights are transposed, scale layout needs to match engine read order
-        # Assuming engine reads (1, 1, In, Out), Scale1 maintains block correspondence
-        # Transpose to [1, 1, n_blocks, Out] to logically match HWIO order
-        s1 = self.weight_quant.scale_1_uint4  # Shape: [Out, Blocks, 1]
-        s1_permuted = (
-            s1.view(self.out_features, -1).t().contiguous()
-        )  # [Out, Blocks] -> [Blocks, Out]
-        s1_hwio = s1_permuted.view(
-            1, 1, -1, self.out_features
-        )  # Shape: [1, 1, Blocks, Out]
+        # Convert weight to int4 (represent as int8)
+        assert self.weight.shape[-1] % self.block_size[1] == 0
+        assert linear_zero_point.sum() == 0
+        weight_int4 = _quantize_affine(
+            self.weight,
+            self.block_size,
+            linear_scale,
+            linear_zero_point,
+            torch.int32,
+            quant_min=-7,
+            quant_max=7,
+        ).to(torch.int8)
+
+        # LPBQ Scale Quantization
+        # Quantize fp32 scale to uint4 scale
+        bitwidth_of_scale = 4
+        num_channels = linear_scale.shape[0]  # [O, I / block_size[1]]
+        num_steps = 2**bitwidth_of_scale
+        quant_scales_dtype = torch.uint8
+        quantized_scales = []
+        level_2_scales = []
+        for ch in range(num_channels):
+            candidates = linear_scale[ch]
+            max_scale = candidates.reshape(1, -1).amax(dim=-1) / num_steps
+            q_scales = torch.clamp(
+                input=torch.round(input=candidates / max_scale),
+                min=1,
+                max=2**bitwidth_of_scale,
+            ).to(quant_scales_dtype)
+            quantized_scales.append(q_scales)
+            level_2_scales.append(max_scale)
+        quantized_scales = torch.cat(quantized_scales)  # [level 1, scale is uint4]
+        level_2_scales = torch.cat(level_2_scales)  # [level 2, scale is fp32]
+
+        # Reformat Linear weight layout(OI) to Conv2d layout(HWIO,H=1,W=1)
+        weight_int4 = (
+            weight_int4.T.contiguous()
+            .view(1, 1, self.in_features, self.out_features)
+            .contiguous()
+        )

         del self.weight
-        self.register_buffer("weight", w_hwio)
-        self.register_buffer("scale1", s1_hwio)
-        self.register_buffer("scale2", s2_hwio)
+        self.register_buffer("weight", weight_int4)
+        self.register_buffer("scale1", quantized_scales.flatten())
+        self.register_buffer("scale2", level_2_scales.flatten())
         del self.weight_quant
-        self.deploy_mode = True

         print(
-            f"[{self.__class__.__name__}] Converted to HWIO.\n"
-            f"  Weight: {self.weight.shape}\n"
-            f"  Scale1: {self.scale1.shape} (Blocks, Out)\n"
-            f"  Scale2: {self.scale2.shape} (1, Out)"
+            f"[{self.__class__.__name__}] Converted to HWIO. Weight: {self.weight.shape}",
+            f"Scale1(uint4): {self.scale1.shape}",
+            f"Scale2(fp32): {self.scale2.shape}",
         )
diff --git a/pymllm/backends/qualcomm/transformers/core/rms_norm.py b/pymllm/backends/qualcomm/transformers/core/rms_norm.py
index e501031d..e55fe079 100644
--- a/pymllm/backends/qualcomm/transformers/core/rms_norm.py
+++ b/pymllm/backends/qualcomm/transformers/core/rms_norm.py
@@ -119,9 +119,12 @@ def freeze_weight(self):
             f"Class: {class_name}, Instance: {instance_class_name}, Weight Quantized: scale={self.weight_fake_quant.scale}, zp={self.weight_fake_quant.zero_point}"
         )

-    def disable_quant(self):
+    def disable_fakequant(self):
         """Completely turn off quantization noise and return to floating point mode"""
-        self.weight_fake_quant.disable_fakequant()
+        self.weight_fake_quant.disable_fake_quant()
+
+    def enable_fakequant(self):
+        self.weight_fake_quant.enable_fake_quant()

     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py
index add7ac44..2dabf5c9 100644
--- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py
+++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py
@@ -65,13 +65,13 @@ def __init__(self, config):
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = QLinearLPBQ(
-            self.hidden_size, self.intermediate_size, bias=False, block_size=32
+            self.hidden_size, self.intermediate_size, bias=False, block_size=16
         )
         self.up_proj = QLinearLPBQ(
-            self.hidden_size, self.intermediate_size, bias=False, block_size=32
+            self.hidden_size, self.intermediate_size, bias=False, block_size=16
         )
         self.down_proj = QLinearLPBQ(
-            self.intermediate_size, self.hidden_size, bias=False, block_size=32
+            self.intermediate_size, self.hidden_size, bias=False, block_size=16
         )

         # QDQ
@@ -173,25 +173,25 @@ def __init__(self, config: Qwen3Config, layer_idx: int):
             config.hidden_size,
             config.num_attention_heads * self.head_dim,
             bias=config.attention_bias,
-            block_size=32,
+            block_size=16,
         )
         self.k_proj = QLinearLPBQ(
             config.hidden_size,
             config.num_key_value_heads * self.head_dim,
             bias=config.attention_bias,
-            block_size=32,
+            block_size=16,
         )
         self.v_proj = QLinearLPBQ(
             config.hidden_size,
             config.num_key_value_heads * self.head_dim,
             bias=config.attention_bias,
-            block_size=32,
+            block_size=16,
         )
         self.o_proj = QLinearLPBQ(
             config.num_attention_heads * self.head_dim,
             config.hidden_size,
             bias=config.attention_bias,
-            block_size=32,
+            block_size=16,
         )
         self.q_norm = QRMSNorm(
             self.head_dim, eps=config.rms_norm_eps, quant_bits=16
@@ -699,7 +699,7 @@ def __init__(self, config):
         self.model = Qwen3Model(config)
         self.vocab_size = config.vocab_size
         self.lm_head = QLinearLPBQ(
-            config.hidden_size, config.vocab_size, bias=False, block_size=32
+            config.hidden_size, config.vocab_size, bias=False, block_size=16
         )

         self.mllm_qualcomm_max_length = None
diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/backends/qualcomm/transformers/qwen3/runner.py
index 45444232..02ea6a5f 100644
--- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py
+++ b/pymllm/backends/qualcomm/transformers/qwen3/runner.py
@@ -166,11 +166,23 @@ def enable_qdq_observer(m):

     def enable_fake_quant(m):
         if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ):
             m.enable_fakequant()
+        if isinstance(m, QLinearLPBQ):
+            m.enable_fakequant()
+        if isinstance(m, QRMSNorm):
+            m.enable_fakequant()
+        if isinstance(m, QEmbedding):
+            m.enable_fakequant()

     def disable_fake_quant(m):
         if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ):
             m.disable_fakequant()
+        if isinstance(m, QLinearLPBQ):
+            m.disable_fakequant()
+        if isinstance(m, QRMSNorm):
+            m.disable_fakequant()
+        if isinstance(m, QEmbedding):
+            m.disable_fakequant()

     def convert_weight(m):
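
Note (reviewer sketch, not part of the patch): the reworked convert_to_conv2d_deploy_hwio packs each Linear weight as LPBQ data: int4 weights in [-7, 7], per-block level-1 scales stored as uint4 multipliers of a per-channel fp32 level-2 scale, and an HWIO [1, 1, In, Out] layout for the 1x1 Conv2D that replaces the Linear. The snippet below reproduces that packing with a plain min/max derivation of the block scales instead of the torchao observer path the patch uses; lpbq_pack and its parameter names are illustrative only, and block_size = 16 matches the new config.

    import torch

    def lpbq_pack(weight: torch.Tensor, block_size: int = 16, scale_bits: int = 4):
        # weight: [Out, In] fp32, In divisible by block_size
        out_features, in_features = weight.shape
        assert in_features % block_size == 0
        num_steps = 2 ** scale_bits  # 16 levels for the uint4 level-1 scale

        # Per-block fp32 scales for symmetric int4 weights in [-7, 7].
        blocks = weight.view(out_features, in_features // block_size, block_size)
        block_scale = blocks.abs().amax(dim=-1) / 7.0  # [Out, NumBlocks]

        # Level 2: one fp32 scale per output channel.
        level2 = (block_scale.amax(dim=-1, keepdim=True) / num_steps).clamp_min(1e-12)  # [Out, 1]
        # Level 1: block scales as integer multipliers, clamped to [1, num_steps] as in the patch.
        level1 = torch.clamp(torch.round(block_scale / level2), 1, num_steps).to(torch.uint8)

        # Quantize the weights against the reconstructed per-block scale.
        recon = (level1.to(torch.float32) * level2).unsqueeze(-1)  # [Out, NumBlocks, 1]
        w_int4 = torch.clamp(torch.round(blocks / recon), -7, 7).to(torch.int8)

        # Reformat OI -> HWIO [1, 1, In, Out] for the Conv2D replacement of Linear.
        w_hwio = w_int4.view(out_features, in_features).T.contiguous().view(1, 1, in_features, out_features)
        return w_hwio, level1.flatten(), level2.flatten()

    w, s1, s2 = lpbq_pack(torch.randn(64, 128))
    print(w.shape, s1.shape, s2.shape)  # [1, 1, 128, 64], [512], [64]

Halving the block size to 16 doubles the number of level-1 scales per channel, so the uint4 multipliers track the weight distribution more closely at the cost of a larger scale1 buffer.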