From 90658839b784d711c578414892480db42a319648 Mon Sep 17 00:00:00 2001
From: Super User
Date: Thu, 16 Apr 2026 11:50:10 +0800
Subject: [PATCH] feat(qat): add lwc and distillation loss

---
 angelslim/compressor/qat/modules/quantizer.py | 176 ++++++++++++++--
 .../compressor/qat/plugins/learnable_scale.py |  22 +-
 .../qat/trainers/end2end_trainer.py           | 189 +++++++++++++++++-
 angelslim/utils/config_parser.py              |   4 +
 ...qwen3-4b_w4a8_fp8_end2end_learn_scale.yaml |   4 +
 ...3-4b_w4a8_fp8_end2end_learn_scale_lwc.yaml |  75 +++++++
 ...8_end2end_learn_scale_lwc_qkv_fp8attn.yaml |  91 +++++++++
 ...8_fp8_end2end_learn_scale_qkv_fp8attn.yaml |  10 +-
 docs/source/features/quantization/qat.md      |  75 ++++++-
 9 files changed, 606 insertions(+), 40 deletions(-)
 create mode 100644 configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_lwc.yaml
 create mode 100644 configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_lwc_qkv_fp8attn.yaml

diff --git a/angelslim/compressor/qat/modules/quantizer.py b/angelslim/compressor/qat/modules/quantizer.py
index 3bc215d0..bc81e998 100644
--- a/angelslim/compressor/qat/modules/quantizer.py
+++ b/angelslim/compressor/qat/modules/quantizer.py
@@ -69,6 +69,7 @@ def __init__(self, config, quant_info, x=None, is_act=False, resume=False, num_h
         self._apply_settings(info, rewrite_conf)
         self._set_quant_range()
         self._init_quant_params(x)
+        self._init_lwc_params(x, config)

     def _apply_settings(self, info, rewrite_conf):
         if rewrite_conf:
@@ -130,6 +131,39 @@ def _init_quant_params(self, x):
             )
             self._set_quant_parameters(scale, zp.round())

+    def _init_lwc_params(self, x, config):
+        lwc_cfg = config.get("lwc", {})
+        if isinstance(lwc_cfg, dict):
+            self.lwc = (not self.is_act) and bool(lwc_cfg.get("enable_lwc", False))
+            self.lwc_init_value = float(lwc_cfg.get("lwc_init_value", 4.0))
+        else:
+            self.lwc = (not self.is_act) and bool(lwc_cfg)
+            self.lwc_init_value = 4.0
+
+        if self.lwc:
+            x_for_shape = x.flatten(1) if x.dim() != 2 else x
+            out_dim, in_dim = x_for_shape.shape
+            # One clip factor per output channel, and per group for per-group.
+            if self.granularity == "per-group":
+                if not self.group_size or self.group_size <= 0:
+                    raise ValueError("per-group quantization requires positive group_size.")
+                if in_dim % self.group_size != 0:
+                    raise ValueError(
+                        f"in_dim ({in_dim}) must be divisible by group_size ({self.group_size})."
+                    )
+                n_groups = in_dim // self.group_size
+                dim1 = out_dim * n_groups
+            elif self.granularity == "per-channel":
+                dim1 = out_dim
+            else:
+                dim1 = 1
+
+            init = (
+                torch.ones((dim1, 1), device=x.device, dtype=torch.float32) * self.lwc_init_value
+            )
+            self.clip_factor_w_max = nn.Parameter(init.clone(), requires_grad=True)
+            self.clip_factor_w_min = nn.Parameter(init.clone(), requires_grad=True)
+            self.sigmoid = nn.Sigmoid()
+
     def _compute_scales(self, x, granularity="per-tensor", group_size=-1):
         if granularity == "per-tensor":
             s = torch.clamp(torch.max(torch.abs(x.flatten())), min=1e-8)
@@ -309,47 +343,145 @@ def _expand(t, target_shape):

         elif self.granularity == "per-token":
-            # scale: [n_tokens, 1] -> [n_tokens, in_features] then reshape to x.shape
-            init_shape = x.shape
-            rx = x.reshape(-1, x.shape[-1])
-            scale = _expand(scale, rx.shape).reshape(init_shape)
-            zero_point = (
-                _expand(zero_point, rx.shape).reshape(init_shape)
-                if zero_point is not None
-                else None
-            )
+            # scale: [..., 1] broadcasts across the last (feature) dim of x; no reshape needed
+            scale = _expand(scale, x.shape)
+            zero_point = _expand(zero_point, x.shape) if zero_point is not None else None
+
+        elif self.granularity == "per-head":
+            if self.num_heads <= 0:
+                raise ValueError("num_heads must be set for per-head granularity.")
+            if x.shape[-1] % self.num_heads != 0:
+                raise ValueError(
+ f"last dim ({x.shape[-1]}) must be divisible by num_heads ({self.num_heads})" + ) + head_dim = x.shape[-1] // self.num_heads + head_shape = (*x.shape[:-1], self.num_heads, head_dim) + + def _expand_per_head(t): + if t is None: + return None + # Broadcast one scale per head across that head's contiguous feature slice. + view_shape = (1,) * (x.dim() - 1) + (self.num_heads, 1) + return t.reshape(view_shape).expand(head_shape).reshape(x.shape) + + scale = _expand_per_head(scale) + zero_point = _expand_per_head(zero_point) return scale, zero_point - def fake_quant(self, x): - scale = clamp_ste(self.scale, 1e-4, 1e4) - round_zero_point = ( - None if self.is_sym else clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax) - ) - scale, round_zero_point = self._expand_scale_zp(scale, round_zero_point, x) + def _expand_scale_zp_lwc(self, scale, zero_point, x): + def _expand(t, target_shape): + if t is None: + return None + return t.expand(target_shape) + + if self.granularity == "per-channel": + scale = _expand(scale, x.shape) + zero_point = _expand(zero_point, x.shape) + + elif self.granularity == "per-group": + group_size = self.group_size + scale = scale.unsqueeze(-1).expand(*scale.shape, group_size).reshape(x.shape) + if zero_point is not None: + zero_point = ( + zero_point.unsqueeze(-1).expand(*zero_point.shape, group_size).reshape(x.shape) + ) + + return scale, zero_point + + def _lwc_fake_quant_weight(self, x: torch.Tensor) -> torch.Tensor: + # Weight-only LWC path (OmniQuant-style): + # compute scale/zp from (possibly grouped) xmin/xmax + # with learnable bound factors, then quantize with STE. + x_dtype = x.dtype + x_work = x + if x_work.dim() != 2: + x_work = x_work.flatten(1) + + if self.granularity == "per-group": + out_dim, in_dim = x_work.shape + x_reduce = x_work.reshape(out_dim, in_dim // self.group_size, self.group_size) + xmin = x_reduce.amin(dim=-1) + xmax = x_reduce.amax(dim=-1) + elif self.granularity == "per-channel": + xmin = x_work.amin(dim=-1, keepdim=True) + xmax = x_work.amax(dim=-1, keepdim=True) + else: + # per-tensor (default) + xmin = x_work.amin().view(1, 1) + xmax = x_work.amax().view(1, 1) + + xmax = self.sigmoid(self.clip_factor_w_max).reshape_as(xmax) * xmax + xmin = self.sigmoid(self.clip_factor_w_min).reshape_as(xmin) * xmin + if self.is_sym: + abs_max = torch.max(xmax.abs(), xmin.abs()) + scale = (abs_max / self.qmax).to(dtype=x_work.dtype) + round_zero_point = None + else: + range_ = xmax - xmin + scale = (range_ / (self.qmax - self.qmin)).to(dtype=x_work.dtype) + zero_point = (-xmin) / range_ * (self.qmax - self.qmin) + self.qmin + round_zero_point = clamp_ste(round_ste(zero_point), self.qmin, self.qmax) + scale, round_zero_point = self._expand_scale_zp_lwc(scale, round_zero_point, x_work) + x_dequant = self._fake_quant_with_params(x_work, scale, round_zero_point) + return x_dequant.to(dtype=x_dtype).reshape(x.shape) + + def _w4a8_fp8_ste_from_dequant( + self, x_dequant: torch.Tensor, scale: torch.Tensor + ) -> torch.Tensor: + fp8_scale = scale.max() * self.qmax / FP8_E4M3_QMAX + weight_fp8 = x_dequant / fp8_scale + weight_fp8_q = weight_fp8.clamp(FP8_E4M3_QMIN, FP8_E4M3_QMAX).to(torch.float8_e4m3fn) + weight_fp8_q = (weight_fp8_q.to(torch.bfloat16) - weight_fp8).detach() + weight_fp8 + return weight_fp8_q * fp8_scale + + def _fake_quant_with_params( + self, + x: torch.Tensor, + scale: torch.Tensor, + round_zero_point: torch.Tensor | None, + ) -> torch.Tensor: + scale = clamp_ste(scale, 1e-4, 1e4) + if round_zero_point is not None: + round_zero_point = 
clamp_ste(round_zero_point, self.qmin, self.qmax)

         if self.is_w4a8_fp8:
             x_int4 = round_ste(x / scale)
-            x_int4 = clamp_ste(x_int4, self.qmin, self.qmax).mul(scale)
-            fp8_scale = scale.max() * self.qmax / FP8_E4M3_QMAX
-            weight_fp8 = (x_int4 / fp8_scale).clamp(-448, 448).to(torch.float8_e4m3fn)
-            return weight_fp8.to(torch.bfloat16) * fp8_scale
+            x_int4 = clamp_ste(x_int4, self.qmin, self.qmax)
+            x_dequant = x_int4.mul(scale)
+            return self._w4a8_fp8_ste_from_dequant(x_dequant, scale)

         if self.dtype == "fp8":
-            weight_fp8 = (x / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
-            return weight_fp8.to(torch.bfloat16) * scale
+            weight_fp8 = x / scale
+            weight_fp8 = clamp_ste(weight_fp8, FP8_E4M3_QMIN, FP8_E4M3_QMAX)
+            weight_fp8_q = weight_fp8.to(torch.float8_e4m3fn).to(torch.bfloat16)
+            # STE: forward uses the FP8-rounded value, backward passes gradients through.
+            weight_fp8_q = (weight_fp8_q - weight_fp8).detach() + weight_fp8
+            return weight_fp8_q * scale

         x_int = round_ste(x / scale)
         if round_zero_point is not None:
             x_int = x_int.add(round_zero_point)
         x_int = clamp_ste(x_int, self.qmin, self.qmax)
+        x_dequant = x_int
         if round_zero_point is not None:
-            x_int = x_int.sub(round_zero_point)
-        return x_int.mul(scale)
+            x_dequant = x_dequant.sub(round_zero_point)
+        x_dequant = x_dequant.mul(scale)
+        return x_dequant
+
+    def fake_quant(self, x):
+        scale = clamp_ste(self.scale, 1e-4, 1e4)
+        round_zero_point = (
+            None if self.is_sym else clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax)
+        )
+        scale, round_zero_point = self._expand_scale_zp(scale, round_zero_point, x)
+        return self._fake_quant_with_params(x, scale, round_zero_point)

     def forward(self, x: torch.Tensor):
         if self.bits >= 16:
             return x

+        if self.lwc:
+            return self._lwc_fake_quant_weight(x)
+
         if self.is_act and not self.dynamic and not self.init:
             self._lazy_init(x)
             return x
diff --git a/angelslim/compressor/qat/plugins/learnable_scale.py b/angelslim/compressor/qat/plugins/learnable_scale.py
index d79a420d..8f01514f 100644
--- a/angelslim/compressor/qat/plugins/learnable_scale.py
+++ b/angelslim/compressor/qat/plugins/learnable_scale.py
@@ -40,8 +40,9 @@ def __init__(self, quant_info=None, ignore_layers=None, resume_ckpt_dir=None, **

         # Parse learnable config (boolean switches for each parameter group)
         learnable_cfg = self.config.get("learnable", {})
-        self.learn_act_scale = learnable_cfg.get("act_scale", True)
-        self.learn_weight_scale = learnable_cfg.get("weight_scale", False)
+        self.learn_act_scale = learnable_cfg.get("act_scale", False)
+        self.learn_weight_scale = learnable_cfg.get("weight_scale", True)
+        self.learn_lwc = learnable_cfg.get("lwc", False)
         self.learn_kv_scale = learnable_cfg.get("kv_scale", False)
         self.learn_norm = learnable_cfg.get("norm", False)

@@ -100,6 +101,7 @@ def _apply_learn_strategy(self):
             act_scale=self.learn_act_scale,
             weight_scale=self.learn_weight_scale,
             kv_scale=self.learn_kv_scale,
+            lwc=self.learn_lwc,
         )

         if self.learn_norm:
@@ -109,7 +111,8 @@ def _apply_learn_strategy(self):
             f"act_scale={self.learn_act_scale}, "
             f"weight_scale={self.learn_weight_scale}, "
             f"kv_scale={self.learn_kv_scale}, "
-            f"norm={self.learn_norm}"
+            f"norm={self.learn_norm}, "
+            f"lwc={self.learn_lwc}"
         )
         print_info(
             f"Learnable config ({learnable_summary}): "
@@ -205,7 +208,7 @@ def quant_parameters(model):
 def set_weight_parameters(model, requires_grad):
     params = []
     for n, m in model.named_parameters():
-        if n.find("weight") > -1 and not (n.find("scale") > -1 or n.find("zero_point") > -1):
+        if n.endswith("weight") and not (n.find("scale") > -1 or n.find("zero_point") > -1):
             m.requires_grad = requires_grad
     return iter(params)


@@ -213,7 +216,7 @@ def set_weight_parameters(model, requires_grad):
 def weight_parameters(model):
     params = []
     for n, m in model.named_parameters():
-        if n.find("weight") > -1 and not (n.find("scale") > -1 or n.find("zero_point") > -1):
+        if n.endswith("weight") and not (n.find("scale") > -1 or n.find("zero_point") > -1):
             params.append(m)
     return iter(params)

@@ -226,7 +229,9 @@ def trainable_parameters(model):
     return iter(params)


-def _set_learnable_parameters(model, act_scale=False, weight_scale=False, kv_scale=False):
+def _set_learnable_parameters(
+    model, act_scale=False, weight_scale=False, kv_scale=False, lwc=False
+):
     _KV_SUFFIXES = ("k_proj", "v_proj")

     for name, module in model.named_modules():
@@ -243,6 +248,11 @@ def _set_learnable_parameters(model, act_scale=False, weight_scale=False, kv_sca
             if "scale" in pname or "zero_point" in pname:
                 param.requires_grad = True

+        if lwc and hasattr(module, "weight_quantizer"):
+            for pname, param in module.weight_quantizer.named_parameters():
+                if "clip_factor_w_max" in pname or "clip_factor_w_min" in pname:
+                    param.requires_grad = True
+
         if kv_scale and hasattr(module, "qkv_quantizer"):
             suffix = name.rsplit(".", 1)[-1] if "." in name else name
             if suffix in _KV_SUFFIXES:
diff --git a/angelslim/compressor/qat/trainers/end2end_trainer.py b/angelslim/compressor/qat/trainers/end2end_trainer.py
index 5e36d6b0..fb5daf6d 100644
--- a/angelslim/compressor/qat/trainers/end2end_trainer.py
+++ b/angelslim/compressor/qat/trainers/end2end_trainer.py
@@ -13,14 +13,130 @@
 # limitations under the License.

 import torch
+import torch.nn.functional as F
 from datasets import load_dataset
 from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

 from ....data.qat_dataset import QATDataset
 from ....utils import print_info
+from ..plugins.learnable_scale import set_quant_state
 from .trainer_factory import TrainerFactory


+class QATSeq2SeqTrainer(Seq2SeqTrainer):
+    def __init__(self, *args, loss_config=None, quant_config=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        loss_config = loss_config or {}
+        quant_config = quant_config or {}
+        self.loss_type = str(loss_config.get("loss_type", "origin")).lower()
+        self.loss_topk = loss_config.get("loss_topk")
+        self.kd_temperature = float(loss_config.get("kd_temperature", 1.0))
+        self.kd_alpha = float(loss_config.get("kd_alpha", 0.5))
+        self.use_weight_quant = quant_config.get("use_weight_quant", False)
+        self.use_activation_quant = quant_config.get("use_activation_quant", False)
+        self.use_qkv_quant = quant_config.get("use_qkv_quant", False)
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        if self.loss_type == "origin":
+            return super().compute_loss(
+                model,
+                inputs,
+                return_outputs=return_outputs,
+                num_items_in_batch=num_items_in_batch,
+            )
+
+        # Teacher pass: the same model with all fake quantization disabled.
+        teacher_logits = self.get_ori_outputs(model, inputs).logits
+        student_inputs = dict(inputs)
+        if self.loss_type != "kd":
+            student_inputs.pop("labels", None)
+        outputs = model(**student_inputs)
+        student_logits = outputs.logits
+
+        # Note: F.kl_div(input, target) computes KL(target || input),
+        # with `input` given as log-probabilities.
+        if self.loss_type == "kl":
+            # Forward KL(teacher || student).
+            loss = F.kl_div(
+                F.log_softmax(student_logits.flatten(0, -2), dim=-1),
+                F.softmax(teacher_logits.flatten(0, -2), dim=-1),
+                reduction="batchmean",
+            )
+        elif self.loss_type == "rkl":
+            # Reverse KL(student || teacher).
+            loss = F.kl_div(
+                F.log_softmax(teacher_logits.flatten(0, -2), dim=-1),
+                F.softmax(student_logits.flatten(0, -2), dim=-1),
+                reduction="batchmean",
+            )
+        elif self.loss_type == "mse":
+            loss = F.mse_loss(student_logits, teacher_logits)
+        elif self.loss_type == "kd":
+            if getattr(outputs, "loss", None) is None:
+                raise ValueError("loss_type='kd' requires labels to compute CE loss.")
+            temperature = max(self.kd_temperature, 1e-6)
+            alpha = self.kd_alpha
+            distill_loss = F.kl_div(
+                F.log_softmax(student_logits.flatten(0, -2) / temperature, dim=-1),
+                F.softmax(teacher_logits.flatten(0, -2) / temperature, dim=-1),
+                reduction="batchmean",
+            )
+            # Standard KD mix: (1 - alpha) * CE + alpha * T^2 * KL;
+            # the T^2 factor restores the gradient scale of the softened KL term.
+            loss = outputs.loss * (1 - alpha) + distill_loss * (alpha * temperature * temperature)
+        elif self._is_reverse_topk_loss(self.loss_type):
+            # Reverse KL restricted to the student's top-k tokens.
+            topk = self._resolve_topk(self.loss_type, student_logits.size(-1))
+            top_student_logits, indices = student_logits.topk(topk, dim=-1, sorted=False)
+            top_teacher_logits = teacher_logits.gather(-1, indices)
+            loss = F.kl_div(
+                F.log_softmax(top_teacher_logits.flatten(0, -2), dim=-1),
+                F.softmax(top_student_logits.flatten(0, -2), dim=-1),
+                reduction="batchmean",
+            )
+        elif self._is_forward_topk_loss(self.loss_type):
+            # Forward KL restricted to the teacher's top-k tokens.
+            topk = self._resolve_topk(self.loss_type, teacher_logits.size(-1))
+            top_teacher_logits, indices = teacher_logits.topk(topk, dim=-1, sorted=False)
+            top_student_logits = student_logits.gather(-1, indices)
+            loss = F.kl_div(
+                F.log_softmax(top_student_logits.flatten(0, -2), dim=-1),
+                F.softmax(top_teacher_logits.flatten(0, -2), dim=-1),
+                reduction="batchmean",
+            )
+        else:
+            raise ValueError(f"Unsupported QAT loss_type: {self.loss_type}")
+
+        return (loss, outputs) if return_outputs else loss
+
+    @staticmethod
+    def _is_forward_topk_loss(loss_type):
+        return loss_type.startswith("kl_top")
+
+    @staticmethod
+    def _is_reverse_topk_loss(loss_type):
+        return loss_type.startswith("r_kl_top") or loss_type.startswith("rkl_top")
+
+    def _resolve_topk(self, loss_type, vocab_size):
+        # Priority: explicit loss_topk > K inlined in the loss name
+        # (e.g. "kl_top_100") > fallback of 1000; always capped at vocab_size.
+        topk = self.loss_topk
+        if topk is None and "_top_" in loss_type:
+            topk = int(loss_type.rsplit("_", 1)[-1])
+        if topk is None:
+            topk = 1000
+        topk = int(topk)
+        if topk <= 0:
+            raise ValueError(f"loss_topk must be positive, got: {topk}")
+        return min(topk, vocab_size)
+
+    @torch.no_grad()
+    def get_ori_outputs(self, model, inputs):
+        teacher_inputs = dict(inputs)
+        teacher_inputs.pop("labels", None)
+        raw_model = self.accelerator.unwrap_model(model)
+        set_quant_state(raw_model, weight_quant=False, act_quant=False, qkv_quant=False)
+        try:
+            outputs = model(**teacher_inputs)
+        finally:
+            # Always restore the configured quantization state, even if the
+            # teacher forward pass raises.
+            set_quant_state(
+                raw_model,
+                weight_quant=self.use_weight_quant,
+                act_quant=self.use_activation_quant,
+                qkv_quant=self.use_qkv_quant,
+            )
+        return outputs
+
+
 @TrainerFactory.register("end2end")
 class End2EndTrainer:

@@ -36,11 +152,77 @@ def __init__(self, quant_model, config, plugin_manager):
         self.do_train = config["compress_config"].QAT.do_train
         self.external_trainer = None

+        self.loss_config = {
+            "loss_type": config["compress_config"].QAT.loss_type,
+            "loss_topk": config["compress_config"].QAT.loss_topk,
+            "kd_temperature": config["compress_config"].QAT.kd_temperature,
+            "kd_alpha": config["compress_config"].QAT.kd_alpha,
+        }
+        self.quant_config = {
+            "use_weight_quant": config["compress_config"]
+            .QAT.plugin_config.get("quant_config", {})
+            .get("use_weight_quant", False),
+            "use_activation_quant": config["compress_config"]
+            .QAT.plugin_config.get("quant_config", {})
+            .get("use_activation_quant", False),
+            "use_qkv_quant": config["compress_config"]
+            .QAT.plugin_config.get("quant_config", {})
+            .get("use_qkv_quant", False),
+        }
+
+    def _init_optimizer(self):
+        lr = float(self.config["compress_config"].QAT.hf_args.get("learning_rate", 1e-5))
+        wd = float(self.config["compress_config"].QAT.hf_args.get("weight_decay", 0))
+        # Quantizer scale / zero_point parameters share the base learning rate.
+        params = [
+            {
+                "params": [
+                    p
+                    for n, p in self.quant_model.model.named_parameters()
+                    if "scale" in n or "zero_point" in n
+                ],
+                "weight_decay": wd,
+                "lr": lr,
+            }
+        ]
+
+        enable_lwc = (
+            self.config["compress_config"]
+            .QAT.plugin_config.get("quant_config", {})
+            .get("lwc", {})
+            .get("enable_lwc", False)
+        )
+        if enable_lwc:
+            lwc_lr = float(
+                self.config["compress_config"]
+                .QAT.plugin_config.get("quant_config", {})
+                .get("lwc", {})
+                .get("lwc_lr", 1e-1)
+            )
+            # LWC clip factors form their own param group with a (typically larger) lr.
+            lwc_params = [
+                {
+                    "params": [
+                        p
+                        for n, p in self.quant_model.model.named_parameters()
+                        if "clip_factor_w_max" in n or "clip_factor_w_min" in n
+                    ],
+                    "weight_decay": wd,
+                    "lr": lwc_lr,
+                }
+            ]
+            params.extend(lwc_params)
+
+        self.optimizer = torch.optim.AdamW(params)
+        if enable_lwc:
+            print_info(f"Init optimizer with lr={lr} lwc_lr={lwc_lr} weight_decay={wd}")
+        else:
+            print_info(f"Init optimizer with lr={lr} weight_decay={wd}")
+
     def prepare_trainer(self):
         if self.training_mode == "blockwise":
             return
         if self.training_mode == "end2end" and self.dist_mode == "hf":
-            self.external_trainer = Seq2SeqTrainer(
+            self._init_optimizer()
+            self.external_trainer = QATSeq2SeqTrainer(
                 model=self.quant_model.model,
                 processing_class=self.quant_model.tokenizer,
                 args=Seq2SeqTrainingArguments(
@@ -49,6 +231,9 @@ def prepare_trainer(self):
                 ),
                 train_dataset=self.train_dataset,
                 eval_dataset=None,
+                optimizers=(self.optimizer, None),
+                loss_config=self.loss_config,
+                quant_config=self.quant_config,
             )
         else:
             raise NotImplementedError(f"Unsupported distribution mode: {self.dist_mode}")
@@ -68,8 +253,8 @@ def prepare_dataset(self, dataloader):

     def run(self, dataloader):
         self.prepare_dataset(dataloader)
-        self.prepare_trainer()
         self.plugin_manager.call_before_train(train_dataset=self.train_dataset)
+        self.prepare_trainer()

         if self.resume_ckpt_dir is not None:
             print_info(f"Loading from resume {self.resume_ckpt_dir}")
diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py
index c3cdd71e..77481b9e 100644
--- a/angelslim/utils/config_parser.py
+++ b/angelslim/utils/config_parser.py
@@ -272,6 +272,10 @@ class QATTrainingConfig:
     hf_dataset: Optional[str] = None
     do_train: bool = field(default=True)
     resume_ckpt_dir: Optional[str] = None
+    loss_type: str = field(default="origin")
+    loss_topk: Optional[int] = None
+    kd_temperature: float = field(default=1.0)
+    kd_alpha: float = field(default=0.5)
     hf_args: Dict[str, Any] = field(default_factory=dict)


diff --git a/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale.yaml b/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale.yaml
index 52f1089c..91498a60 100644
--- a/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale.yaml
+++ b/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale.yaml
@@ -28,6 +28,10 @@ compression:
     save_format: fake # Set "fake" or "real", or not set(skip save)
     do_train: true
     # resume_ckpt_dir: Resume for fake checkpoint, when save_format=='fake', or not set
+    loss_type: origin # origin | kl | rkl | mse | kd | kl_top[_K] | r_kl_top[_K]
+    loss_topk: null # optional, overrides the K parsed from loss_type
+    kd_temperature: 1.0
+    kd_alpha: 0.5
     plugin_config:
       enable_scale: true
       quant_config:
diff --git a/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_lwc.yaml
b/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_lwc.yaml new file mode 100644 index 00000000..028ad851 --- /dev/null +++ b/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_lwc.yaml @@ -0,0 +1,75 @@ +global: + save_path: ./output + +model: + name: Qwen + model_path: Qwen/Qwen3-4B + trust_remote_code: true + torch_dtype: auto + device_map: auto + low_cpu_mem_usage: true + use_cache: false + +compression: + name: QAT + quantization: + name: w4a8_fp8 + bits: 8 + quant_method: + weight: per-group + activation: per-tensor + group_size: 128 + ignore_layers: ["lm_head", "embed_tokens"] # Skip quantization for these layers + QAT: + hf_dataset: Salesforce/wikitext,wikitext-2-raw-v1 # hf dataset name, will overwrite dataset.data_path + # hf_cache_dir: Specify your cache path, or not set(default) + training_mode: end2end + dist_mode: hf # "hf" is using HF Trainer + save_format: fake # Set "fake" or "real", or not set(skip save) + do_train: true + # resume_ckpt_dir: Resume for fake checkpoint, when save_format=='fake', or not set + loss_type: origin # origin | kl | rkl | mse | kd | kl_top[_K] | r_kl_top[_K] + loss_topk: null # optional, overrides the K parsed from loss_type + kd_temperature: 1.0 + kd_alpha: 0.5 + plugin_config: + enable_scale: true + quant_config: + use_weight_quant: true + use_activation_quant: true + use_qkv_quant: false + lazy_init_samples: 10 + # --- Learnable parameter control --- + # Each switch independently controls whether that parameter group is trainable. + # Model weights themselves always stay frozen. + learnable: + act_scale: false # Activation quantizer scale/zero_point (default: false) + weight_scale: true # Weight quantizer scale/zero_point (default: true) + kv_scale: false # KV cache quantizer scale in k_proj/v_proj (default: false) + norm: false # Norm layer (RMSNorm/LayerNorm) weights (default: false) + lwc: true # LWC quantization (default: false) + # --- LWC quantization --- + lwc: + enable_lwc: true + lwc_init_value: 4.0 + lwc_lr: 0.5 # default 0.5, can be set to 0.2 + + hf_args: + # output_dir: Not to set, same as global.save_path + # other arguments, see https://huggingface.co/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments + logging_steps: 1 + logging_first_step: true + per_device_train_batch_size: 2 + gradient_accumulation_steps: 2 + learning_rate: 1e-5 + lr_scheduler_type: constant + num_train_epochs: 1 + save_strategy: 'no' + +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl + max_seq_length: 2048 + num_samples: 256 + batch_size: 1 + diff --git a/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_lwc_qkv_fp8attn.yaml b/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_lwc_qkv_fp8attn.yaml new file mode 100644 index 00000000..7b0b6ec0 --- /dev/null +++ b/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_lwc_qkv_fp8attn.yaml @@ -0,0 +1,91 @@ +global: + save_path: ./output + +model: + name: Qwen + model_path: Qwen/Qwen3-4B + trust_remote_code: true + torch_dtype: auto + device_map: auto + low_cpu_mem_usage: true + use_cache: false + attn_implementation: eager # Required when fp8_attn is enabled + +compression: + name: QAT + quantization: + name: w4a8_fp8 + bits: 8 + quant_method: + weight: per-group + activation: per-tensor + group_size: 128 + ignore_layers: ["lm_head", "embed_tokens"] # Skip quantization for these layers + 
QAT: + hf_dataset: Salesforce/wikitext,wikitext-2-raw-v1 # hf dataset name, will overwrite dataset.data_path + # hf_cache_dir: Specify your cache path, or not set(default) + training_mode: end2end + dist_mode: hf # "hf" is using HF Trainer + save_format: save_kvcache_only # "save_kvcache_only": only export KV cache scales; "real": save real-quant model; "real_and_kvcache": save real-quant model + KV cache scales; "fake": save fake-quant state_dict (non-distributed only); or not set(skip save) + do_train: true + # resume_ckpt_dir: Resume for fake checkpoint, when save_format=='fake', or not set + loss_type: origin # origin | kl | rkl | mse | kd | kl_top[_K] | r_kl_top[_K] + loss_topk: null # optional, overrides the K parsed from loss_type + kd_temperature: 1.0 + kd_alpha: 0.5 + plugin_config: + enable_scale: true + quant_config: + use_weight_quant: true + use_activation_quant: true + use_qkv_quant: true + lazy_init_samples: 10 + # --- Learnable parameter control --- + # Each switch independently controls whether that parameter group is trainable. + # Model weights themselves always stay frozen. + learnable: + act_scale: false # Activation quantizer scale/zero_point (default: false) + weight_scale: true # Weight quantizer scale/zero_point (default: true) + kv_scale: true # KV cache quantizer scale in k_proj/v_proj (default: false) + norm: false # Norm layer (RMSNorm/LayerNorm) weights (default: false) + lwc: true # LWC quantization (default: false) + # --- LWC quantization --- + lwc: + enable_lwc: true + lwc_init_value: 4.0 + lwc_lr: 0.5 # default 0.5, can be set to 0.2 + # --- QKV / Attention FP8 quantization --- + fp8_attn: true # Enable FP8 attention simulation (cast attn_weights to FP8) + q: # Query projection output quantization + qtype: fp8 + granularity: per-token # Dynamic per-token quantization for Q + group_size: -1 + is_sym: true + k: # Key projection output quantization (KV cache compression) + qtype: fp8 + granularity: per-head # Per-head quantization, one scale per KV head + group_size: -1 + is_sym: true + v: # Value projection output quantization (KV cache compression) + qtype: fp8 + granularity: per-head + group_size: -1 + is_sym: true + hf_args: + # output_dir: Not to set, same as global.save_path + # other arguments, see https://huggingface.co/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments + logging_steps: 1 + logging_first_step: true + per_device_train_batch_size: 2 + gradient_accumulation_steps: 2 + learning_rate: 5e-5 + lr_scheduler_type: constant + num_train_epochs: 1 + save_strategy: 'no' + +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl + max_seq_length: 2048 + num_samples: 256 + batch_size: 1 diff --git a/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_qkv_fp8attn.yaml b/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_qkv_fp8attn.yaml index bd9df1d2..31c35748 100644 --- a/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_qkv_fp8attn.yaml +++ b/configs/qwen3/qat/w4a8_fp8/learn_scale/qwen3-4b_w4a8_fp8_end2end_learn_scale_qkv_fp8attn.yaml @@ -9,6 +9,7 @@ model: device_map: auto low_cpu_mem_usage: true use_cache: false + attn_implementation: eager # Required when fp8_attn is enabled compression: name: QAT @@ -28,18 +29,23 @@ compression: save_format: save_kvcache_only # "save_kvcache_only": only export KV cache scales; "real": save real-quant model; "real_and_kvcache": save real-quant 
model + KV cache scales; "fake": save fake-quant state_dict (non-distributed only); or not set(skip save)
     do_train: true
     # resume_ckpt_dir: Resume for fake checkpoint, when save_format=='fake', or not set
+    loss_type: origin # origin | kl | rkl | mse | kd | kl_top[_K] | r_kl_top[_K]
+    loss_topk: null # optional, overrides the K parsed from loss_type
+    kd_temperature: 1.0
+    kd_alpha: 0.5
     plugin_config:
       enable_scale: true
       quant_config:
         use_weight_quant: true
         use_activation_quant: true
+        use_qkv_quant: true
         lazy_init_samples: 10
         # --- Learnable parameter control ---
         # Each switch independently controls whether that parameter group is trainable.
         # Model weights themselves always stay frozen.
         learnable:
-          act_scale: true # Activation quantizer scale/zero_point (default: true)
-          weight_scale: false # Weight quantizer scale/zero_point (default: false)
+          act_scale: false # Activation quantizer scale/zero_point (default: false)
+          weight_scale: true # Weight quantizer scale/zero_point (default: true)
           kv_scale: false # KV cache quantizer scale in k_proj/v_proj (default: false)
           norm: false # Norm layer (RMSNorm/LayerNorm) weights (default: false)
         # --- QKV / Attention FP8 quantization ---
diff --git a/docs/source/features/quantization/qat.md b/docs/source/features/quantization/qat.md
index b002c1f0..cefb3395 100644
--- a/docs/source/features/quantization/qat.md
+++ b/docs/source/features/quantization/qat.md
@@ -199,10 +199,16 @@ dataset:
         lazy_init_samples: 60 # number of samples used for activation calibration (default: 10)
         # ========== Learnable parameter control (optional) ==========
         learnable:
-          act_scale: true # scale/zero_point of the activation quantizer (default: true)
-          weight_scale: false # scale/zero_point of the weight quantizer (default: false)
+          act_scale: false # scale/zero_point of the activation quantizer (default: false)
+          weight_scale: true # scale/zero_point of the weight quantizer (default: true)
           kv_scale: false # scale of the KV cache quantizer (default: false)
           norm: false # norm layer weights (default: false)
+          lwc: false # LWC clipping parameters clip_factor_w_max/min (default: false)
+        # ========== LWC configuration (optional) ==========
+        lwc:
+          enable_lwc: true # whether LWC is actually enabled; when off, no clip parameters are created
+          lwc_init_value: 4.0 # initial value of clip_factor_w_max/min
+          lwc_lr: 0.5 # end2end mode only: separate learning rate for the LWC parameters
         # Optional: override the default quantization config from compression.quantization
         weight:
           qtype: int8 # weight quantization type (e.g. int4, int8, fp8)
@@ -248,11 +254,31 @@ compression:
   QAT:
     training_mode: "end2end"
     dist_mode: hf
+    loss_type: origin # origin | kl | rkl | mse | kd | kl_top[_K] | r_kl_top[_K]
+    loss_topk: null # top-k KL losses only; overrides the K inlined in loss_type
+    kd_temperature: 1.0 # loss_type=kd only, must be > 0
+    kd_alpha: 0.5 # loss_type=kd only, must lie in [0, 1]
     hf_args:
      # output_dir: /path/to/output is not needed; it follows global.save_path
      # the remaining arguments match HF TrainingArguments
 ```

+The loss-related fields are as follows:
+
+| Field | Type | Default | Description |
+|-------|------|---------|-------------|
+| `loss_type` | str | `origin` | Training objective. `origin` uses the native HF supervised loss; `kl` / `rkl` / `mse` / `kd` align the student logits with the full-precision teacher logits; `kl_top[_K]` / `r_kl_top[_K]` compute the KL only over the top-k tokens |
+| `loss_topk` | int / null | `null` | Only used by `kl_top[_K]` / `r_kl_top[_K]`. Takes precedence over an inline `_K`; if neither is given, the trainer falls back to K = 1000 |
+| `kd_temperature` | float | `1.0` | Only used with `loss_type: kd`; must be greater than 0 |
+| `kd_alpha` | float | `0.5` | Only used with `loss_type: kd`; mixing coefficient between the CE loss and the distillation loss, must lie in `[0, 1]` |
+
+The following constraints apply:
+
+- `loss_type` must be one of `origin`, `kl`, `rkl`, `mse`, `kd`, `kl_top[_K]`, `r_kl_top[_K]`; an unsupported value raises a `ValueError` when the first loss is computed
+- `loss_topk` only applies to the top-k KL losses, must be a positive integer, and is capped at the vocabulary size
+- `kd_temperature` and `kd_alpha` only take effect when `loss_type: kd` and are ignored otherwise
+
 #### Blockwise-specific training configuration

 Blockwise training optimizes the model one Transformer block at a time, which keeps peak GPU memory low. In blockwise mode the QAT configuration likewise lives under `compression.QAT`:
@@ -297,6 +323,7 @@ global:

 - Uses the **STE (Straight-Through Estimator)** to make `round` and `clamp` differentiable, so gradients can backpropagate through the quantization ops
 - Supports **lazy initialization**: for static activation quantization, the scale is calibrated from the first N samples
 - `scale` and `zero_point` are registered as `nn.Parameter` and can be optimized during training
+- Supports **LWC (Learnable Weight Clipping)**: introduces learnable clipping factors `clip_factor_w_max/min` for the weight quantizer
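+
+The LWC clipping itself is a small computation. A minimal sketch, simplified from the per-channel symmetric branch of `_lwc_fake_quant_weight` (the helper name `lwc_fake_quant` and the use of plain `round`/`clamp` instead of the STE variants are illustrative only):
+
+```python
+import torch
+
+def lwc_fake_quant(w, clip_max, clip_min, qmax=7, qmin=-8):
+    # Per-output-channel min/max of a [out, in] weight matrix.
+    xmax = w.amax(dim=-1, keepdim=True)
+    xmin = w.amin(dim=-1, keepdim=True)
+    # Sigmoid keeps the learned shrink factor in (0, 1); the init value 4.0
+    # gives sigmoid(4.0) ~= 0.982, i.e. almost no clipping at the start.
+    xmax = torch.sigmoid(clip_max) * xmax
+    xmin = torch.sigmoid(clip_min) * xmin
+    scale = (torch.max(xmax.abs(), xmin.abs()) / qmax).clamp(min=1e-4)
+    return (w / scale).round().clamp(qmin, qmax) * scale
+
+w = torch.randn(8, 16)
+clip = torch.full((8, 1), 4.0, requires_grad=True)
+w_q = lwc_fake_quant(w, clip, clip)
+```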
 
 **Fake quantization process** (using INT types as the example):
 
@@ -336,22 +363,53 @@ plugin_config:
   enable_scale: true
   quant_config:
     learnable:
-      act_scale: true # scale/zero_point of the activation quantizer (default: true)
-      weight_scale: false # scale/zero_point of the weight quantizer (default: false)
+      act_scale: false # scale/zero_point of the activation quantizer (default: false)
+      weight_scale: true # scale/zero_point of the weight quantizer (default: true)
       kv_scale: false # scale of the KV cache quantizer (the qkv_quantizer in k_proj/v_proj) (default: false)
       norm: false # weights of norm layers (RMSNorm / LayerNorm) (default: false)
+      lwc: false # LWC clipping parameters clip_factor_w_max/min (default: false)
 ```

 | Field | Type | Default | Description |
 |-------|------|---------|-------------|
-| `act_scale` | bool | `true` | Whether to learn the scale / zero_point of the activation quantizer (`act_quantizer`) |
-| `weight_scale` | bool | `false` | Whether to learn the scale / zero_point of the weight quantizer (`weight_quantizer`) |
+| `act_scale` | bool | `false` | Whether to learn the scale / zero_point of the activation quantizer (`act_quantizer`) |
+| `weight_scale` | bool | `true` | Whether to learn the scale / zero_point of the weight quantizer (`weight_quantizer`) |
 | `kv_scale` | bool | `false` | Whether to learn the scale of the KV cache quantizer (`qkv_quantizer`; k_proj / v_proj only) |
 | `norm` | bool | `false` | Whether to learn the weight parameters of norm layers (e.g. `input_layernorm`, `post_attention_layernorm`) |
+| `lwc` | bool | `false` | Whether to learn the LWC clipping parameters `clip_factor_w_max` / `clip_factor_w_min` |

 The switches can be combined freely, e.g. enabling `weight_scale` and `kv_scale` together jointly optimizes the weight-quantization and KV-cache scale parameters.

-> **Note**: if no `learnable` config is provided, the default behavior is equivalent to `act_scale: true` (everything else `false`), i.e. only the activation quantizer's scale parameters are learned.
+> **Note**: if no `learnable` config is provided, the default behavior is equivalent to `act_scale: false`, `weight_scale: true` (everything else `false`), i.e. only the weight quantizer's scale parameters are learned by default.
+
+#### LWC configuration
+
+LWC (Learnable Weight Clipping) introduces a learnable clipping range before weight quantization, which helps reduce the impact of outliers on quantization error. The config lives under `compression.QAT.plugin_config.quant_config.lwc`:
+
+```yaml
+plugin_config:
+  enable_scale: true
+  quant_config:
+    learnable:
+      lwc: true
+    lwc:
+      enable_lwc: true
+      lwc_init_value: 4.0
+      lwc_lr: 0.5
+```
+
+| Field | Type | Default | Description |
+|-------|------|---------|-------------|
+| `enable_lwc` | bool | `false` | Whether LWC is actually enabled. When off, `clip_factor_w_max/min` are never created, so a `lwc` learnable switch has no effect |
+| `lwc_init_value` | float | `4.0` | Initial value of `clip_factor_w_max/min` |
+| `lwc_lr` | float | `0.5` | End-to-end mode only: separate learning rate for the LWC parameter group (the trainer falls back to `0.1` when this field is unset) |
+
+`learnable.lwc` and `lwc.enable_lwc` have different responsibilities:
+
+- `lwc.enable_lwc` decides whether the LWC machinery is created and enabled
+- `learnable.lwc` decides whether the created `clip_factor_w_max/min` participate in training
+
+LWC is only both active and learnable when the two are `true` at the same time.

 ### TrainerFactory — Trainer Factory

@@ -365,8 +423,9 @@ plugin_config:
 ### End-to-End Trainer

 - Trains with the HuggingFace `Seq2SeqTrainer`
-- Uses the `AdamW` optimizer, targeting only the `scale` and `zero_point` parameters (default learning rate 1e-5)
+- Uses the `AdamW` optimizer; by default it optimizes the `scale` and `zero_point` parameters, and when LWC is enabled, `clip_factor_w_max/min` form a separate parameter group that uses `lwc_lr` (see the sketch below)
 - Supports the usual HuggingFace training arguments (lr schedules, gradient accumulation, etc.)
+- Supports the `origin`, `kl`, `rkl`, `mse`, `kd`, `kl_top[_K]`, `r_kl_top[_K]` loss configurations
-- Full execution flow: `prepare_dataset` → `prepare_trainer` → `call_before_train` → optional `resume` → `train` → `call_after_train`
+- Full execution flow: `prepare_dataset` → `call_before_train` → `prepare_trainer` → optional `resume` → `train` → `call_after_train`
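+
+A rough sketch of how the two parameter groups are assembled, mirroring `_init_optimizer` (the standalone helper name `build_qat_optimizer` is illustrative, not part of the API):
+
+```python
+import torch
+
+def build_qat_optimizer(model, lr=1e-5, lwc_lr=0.5, weight_decay=0.0):
+    # Group 1: quantizer scale / zero_point parameters at the base lr.
+    scale_params = [
+        p for n, p in model.named_parameters()
+        if "scale" in n or "zero_point" in n
+    ]
+    # Group 2: LWC clip factors at their own, typically larger, lr.
+    lwc_params = [
+        p for n, p in model.named_parameters()
+        if "clip_factor_w_max" in n or "clip_factor_w_min" in n
+    ]
+    groups = [{"params": scale_params, "lr": lr, "weight_decay": weight_decay}]
+    if lwc_params:
+        groups.append({"params": lwc_params, "lr": lwc_lr, "weight_decay": weight_decay})
+    return torch.optim.AdamW(groups)
+```
 
 ### Blockwise Trainer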