diff --git a/.gitignore b/.gitignore index 00e37b2b..d8fcb3d2 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,8 @@ eval/ *_ckpt*/ output/ outputs/ +output*/ +logs*/ outs/ wandb/ tools/results/ diff --git a/angelslim/compressor/qat/modules/quantizer.py b/angelslim/compressor/qat/modules/quantizer.py index bc81e998..4eb627e4 100644 --- a/angelslim/compressor/qat/modules/quantizer.py +++ b/angelslim/compressor/qat/modules/quantizer.py @@ -18,6 +18,8 @@ import torch.nn as nn import torch.nn.functional as F +from ....utils import is_deepspeed_zero3_enabled, is_zero3_param + FP8_E4M3_QMIN = -448 FP8_E4M3_QMAX = 448 @@ -49,10 +51,30 @@ def _parse_bits_and_dtype(qtype_str): class Quantizer(nn.Module): - def __init__(self, config, quant_info, x=None, is_act=False, resume=False, num_heads=-1): + def __init__( + self, + config, + quant_info, + x=None, + is_act=False, + resume=False, + num_heads=-1, + weight_shape=None, + ): super().__init__() self.is_act = is_act self.num_heads = num_heads + # ``weight_shape`` lets the caller pre-declare the (out_features, + # in_features) of the parent Linear so we can size weight-side + # quantizer Parameters without ever touching the (possibly ZeRO-3 + # sharded) weight tensor. + self.weight_shape = ( + (int(weight_shape[0]), int(weight_shape[1])) if weight_shape is not None else None + ) + # Configurable initial values used when ZeRO-3 is active and we + # cannot depend on the weight data. + self.weight_scale_init_value = float(config.get("weight_scale_init_value", 1.0)) + self.activation_scale_init_value = float(config.get("activation_scale_init_value", 1.0)) info = quant_info.quant_algo_info["w"] self.group_size = quant_info.quant_algo_info.get("w_group_size", -1) rewrite_conf = config.get("weight", {}) @@ -117,8 +139,21 @@ def _init_quant_params(self, x): self.scale = self.zero_point = None if self.resume: self.init = True - zp = torch.empty(1) if not self.is_sym else None - self._set_quant_parameters(torch.empty(1), zp) + init_val = self.activation_scale_init_value + scale = torch.full((1,), init_val, dtype=torch.float32) + zp = torch.zeros(1, dtype=torch.float32) if not self.is_sym else None + self._set_quant_parameters(scale, zp) + return + + # Weight-side path. If we cannot use ``x`` (ZeRO-3 sharded, + # meta, or simply not provided), allocate Parameters by shape + # and ``weight_scale_init_value``. + if self._needs_external_weight_init(x): + shape = self._weight_scale_shape_from_meta() + init_val = self.weight_scale_init_value + scale = torch.full(shape, init_val, dtype=torch.float32) + zp = torch.zeros(shape, dtype=torch.float32) if not self.is_sym else None + self._set_quant_parameters(scale, zp) return if self.is_sym: @@ -131,6 +166,52 @@ def _init_quant_params(self, x): ) self._set_quant_parameters(scale, zp.round()) + def _needs_external_weight_init(self, x): + """True when weight-side init must skip data-dependent computation + and instead allocate Parameters from shape + init_value. + + Triggered by: + * DeepSpeed ZeRO-3 active (HF integration registered) + * ``x`` is a ZeRO-3 sharded Parameter + * ``x`` is None / on meta device / empty + """ + if is_deepspeed_zero3_enabled(): + return True + if x is None: + return True + if is_zero3_param(x): + return True + if hasattr(x, "device") and x.device.type == "meta": + return True + if hasattr(x, "numel") and x.numel() == 0: + return True + return False + + def _weight_2d_shape(self): + """Resolve (out_features, in_features) for the underlying Linear. 
+ Callers must have passed ``weight_shape`` via ``QuantLinear``.""" + if self.weight_shape is not None: + return self.weight_shape + raise RuntimeError( + "Quantizer needs ``weight_shape`` to size weight scale without a " + "concrete tensor (set in QuantLinear.__init__)." + ) + + def _weight_scale_shape_from_meta(self): + out_dim, in_dim = self._weight_2d_shape() + if self.granularity == "per-channel": + return (out_dim, 1) + if self.granularity == "per-group": + if not self.group_size or self.group_size <= 0: + raise ValueError("per-group quantization requires positive group_size.") + if in_dim % self.group_size != 0: + raise ValueError( + f"dim 1 ({in_dim}) not divisible by group_size ({self.group_size})" + ) + return (out_dim, in_dim // self.group_size) + # per-tensor and any reduce-to-scalar variant + return (1,) + def _init_lwc_params(self, x, config): lwc_cfg = config.get("lwc", {}) if isinstance(lwc_cfg, dict): @@ -141,11 +222,18 @@ def _init_lwc_params(self, x, config): self.lwc_init_value = 4.0 if self.lwc: - if x.dim() != 2: - x_for_shape = x.flatten(1) + # Resolve (out_dim, in_dim) without depending on ``x`` data. + if self._needs_external_weight_init(x): + out_dim, in_dim = self._weight_2d_shape() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: - x_for_shape = x - out_dim, in_dim = x_for_shape.shape + if x.dim() != 2: + x_for_shape = x.flatten(1) + else: + x_for_shape = x + out_dim, in_dim = x_for_shape.shape + device = x.device + if self.granularity == "per-group": if not self.group_size or self.group_size <= 0: raise ValueError("per-group quantization requires positive group_size.") @@ -157,9 +245,7 @@ def _init_lwc_params(self, x, config): else: dim1 = 1 - init = ( - torch.ones((dim1, 1), device=x.device, dtype=torch.float32) * self.lwc_init_value - ) + init = torch.ones((dim1, 1), device=device, dtype=torch.float32) * self.lwc_init_value self.clip_factor_w_max = nn.Parameter(init.clone(), requires_grad=True) self.clip_factor_w_min = nn.Parameter(init.clone(), requires_grad=True) self.sigmoid = nn.Sigmoid() @@ -473,7 +559,14 @@ def fake_quant(self, x): None if self.is_sym else clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax) ) scale, round_zero_point = self._expand_scale_zp(scale, round_zero_point, x) - return self._fake_quant_with_params(x, scale, round_zero_point) + out = self._fake_quant_with_params(x, scale, round_zero_point) + # Scale is kept in fp32 for numerical stability, but multiplying by + # a bf16/fp16 activation upcasts the result. Cast back to the input + # dtype so downstream F.linear / DeepSpeed autocast wrappers see a + # consistent dtype. + if out.dtype != x.dtype: + out = out.to(x.dtype) + return out def forward(self, x: torch.Tensor): if self.bits >= 16: @@ -516,8 +609,18 @@ def __init__( self.register_parameter("bias", org_module.bias) self.use_weight_quant = use_weight_quant self.use_act_quant = use_act_quant + # Under ZeRO-3 the weight Parameter ``org_module.weight`` may be a + # zero-numel shard. Pass an explicit (out, in) shape so the weight + # quantizer can size its scale Parameter from the Linear shape + # rather than inspecting the (possibly sharded) tensor. 
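+        # For orientation, the scale shapes ``_weight_scale_shape_from_meta``
+        # derives from this (out, in) pair (illustrative, not enforced here):
+        #   per-channel              -> (out, 1)
+        #   per-group (group_size=g) -> (out, in // g)
+        #   per-tensor               -> (1,)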
+ weight_shape = (org_module.out_features, org_module.in_features) if self.use_weight_quant: - self.weight_quantizer = Quantizer(config, quant_info, x=org_module.weight) + self.weight_quantizer = Quantizer( + config, + quant_info, + x=org_module.weight, + weight_shape=weight_shape, + ) if self.use_act_quant: self.act_quantizer = Quantizer(config, quant_info, is_act=True, resume=resume) @@ -531,13 +634,16 @@ def __init__( ) def forward(self, input: torch.Tensor): - if input.shape[0] == 0: - return self.fwd_func(input, self.weight, self.bias) - weight = self.weight_quantizer(self.weight) if self.use_weight_quant else self.weight if self.use_act_quant: input = self.act_quantizer(input) - output = self.fwd_func(input, weight, self.bias) + # Defensive dtype alignment: upstream (DeepSpeed ZeRO-3 / HF + # autocast) may have cast ``input`` to fp16 even though we run in + # bf16. Align to the (fake-quantised) weight dtype so F.linear + # stays consistent. + output = self.fwd_func( + input.to(self.weight.dtype), weight.to(self.weight.dtype), self.bias + ) if self.use_qkv_quant: output = self.qkv_quantizer(output) return output diff --git a/angelslim/compressor/qat/plugins/distill_loss.py b/angelslim/compressor/qat/plugins/distill_loss.py new file mode 100644 index 00000000..4024b3e4 --- /dev/null +++ b/angelslim/compressor/qat/plugins/distill_loss.py @@ -0,0 +1,117 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + + +class DistillLoss: + def __init__(self, loss_type="kl", loss_topk=None, kd_temperature=1.0): + self.loss_type = str(loss_type).lower() + self.loss_topk = loss_topk + self.kd_temperature = float(kd_temperature) + + @staticmethod + def _kl_per_token(log_p_src: torch.Tensor, p_tgt: torch.Tensor) -> torch.Tensor: + """Per-token KL(tgt || src) (shape [N]).""" + return F.kl_div(log_p_src, p_tgt, reduction="none").sum(dim=-1) + + def compute(self, student_logits, teacher_logits, labels): + """Return per-token KD losses computed only on valid labels. + + The returned dict always contains ``loss``, ``forward_kl`` and + ``backward_kl`` so callers can log diagnostics for every KD variant. + """ + flat_mask = (labels != -100).reshape(-1) + if flat_mask.sum() == 0: + zero = student_logits.new_zeros(()) + return {"loss": zero, "forward_kl": zero, "backward_kl": zero} + + # Flatten to [N, V] and keep only valid tokens. + s_flat = student_logits.flatten(0, -2)[flat_mask] + t_flat = teacher_logits.flatten(0, -2)[flat_mask] + valid_labels = labels.reshape(-1)[flat_mask] + + # Diagnostic KL (always computed at T=1 on valid tokens). + s_logp = F.log_softmax(s_flat, dim=-1) + t_logp = F.log_softmax(t_flat, dim=-1) + s_p = s_logp.exp() + t_p = t_logp.exp() + forward_kl = self._kl_per_token(s_logp, t_p).mean() + backward_kl = self._kl_per_token(t_logp, s_p).mean() + + # Main KD loss according to loss_type. 
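+        # For reference, with p_t / p_s the teacher / student distributions
+        # over valid tokens and T = kd_temperature (a sketch of the branches
+        # below, not extra behaviour):
+        #   kl    -> mean KL(p_t || p_s)            (mode-covering)
+        #   rkl   -> mean KL(p_s || p_t)            (mode-seeking)
+        #   kd    -> T^2 * mean KL(softmax(t/T) || softmax(s/T))
+        #   cakld -> conf * rkl + (1 - conf) * kl,  conf = p_t(label)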
+ if self.loss_type == "kl": + kd = forward_kl + elif self.loss_type == "rkl": + kd = backward_kl + elif self.loss_type == "mse": + kd = F.mse_loss(s_flat, t_flat) + elif self.loss_type == "kd": + # Legacy "kd": temperature-scaled forward KL. The outer trainer + # combines this value with the LM loss by configured weights. + temperature = max(self.kd_temperature, 1e-6) + kd = self._kl_per_token( + F.log_softmax(s_flat / temperature, dim=-1), + F.softmax(t_flat / temperature, dim=-1), + ).mean() * (temperature * temperature) + elif self.loss_type == "cakld": + # Contextual Asymmetric KL-divergence: per-token mixing of + # forward / reverse KL by teacher's confidence on the label. + per_tok_fkl = self._kl_per_token(s_logp, t_p) + per_tok_bkl = self._kl_per_token(t_logp, s_p) + conf = torch.gather(t_p, dim=-1, index=valid_labels.unsqueeze(-1)).squeeze(-1) + kd = (conf * per_tok_bkl + (1.0 - conf) * per_tok_fkl).mean() + elif self._is_reverse_topk_loss(self.loss_type): + topk = self._resolve_topk(self.loss_type, s_flat.size(-1)) + top_s, idx = s_flat.topk(topk, dim=-1, sorted=False) + top_t = t_flat.gather(-1, idx) + kd = self._kl_per_token( + F.log_softmax(top_t, dim=-1), + F.softmax(top_s, dim=-1), + ).mean() + elif self._is_forward_topk_loss(self.loss_type): + topk = self._resolve_topk(self.loss_type, t_flat.size(-1)) + top_t, idx = t_flat.topk(topk, dim=-1, sorted=False) + top_s = s_flat.gather(-1, idx) + kd = self._kl_per_token( + F.log_softmax(top_s, dim=-1), + F.softmax(top_t, dim=-1), + ).mean() + else: + raise ValueError( + f"Unsupported QAT kd loss_type: {self.loss_type}. " + "Valid: kl, rkl, mse, kd, cakld, kl_top[_K], r_kl_top[_K]." + ) + + return {"loss": kd, "forward_kl": forward_kl, "backward_kl": backward_kl} + + @staticmethod + def _is_forward_topk_loss(loss_type): + return loss_type.startswith("kl_top") + + @staticmethod + def _is_reverse_topk_loss(loss_type): + return loss_type.startswith("r_kl_top") or loss_type.startswith("rkl_top") + + def _resolve_topk(self, loss_type, vocab_size): + topk = self.loss_topk + if topk is None and "_top_" in loss_type: + topk = int(loss_type.rsplit("_", 1)[-1]) + if topk is None: + topk = 1000 + topk = int(topk) + if topk <= 0: + raise ValueError(f"loss_topk must be positive, got: {topk}") + return min(topk, vocab_size) diff --git a/angelslim/compressor/qat/plugins/learnable_scale.py b/angelslim/compressor/qat/plugins/learnable_scale.py index 8f01514f..0f9ef3cc 100644 --- a/angelslim/compressor/qat/plugins/learnable_scale.py +++ b/angelslim/compressor/qat/plugins/learnable_scale.py @@ -15,7 +15,13 @@ import torch from tqdm import tqdm -from ....utils import print_info, set_op_by_name +from ....utils import ( + gathered_params_if_zero3, + is_deepspeed_zero3_enabled, + print_info, + set_op_by_name, + stream_load_scales, +) from ..modules.quantizer import QuantLinear from .base_plugin import BasePlugin from .plugin_manager import PluginManager @@ -29,11 +35,22 @@ @PluginManager.plugin("learnable_scale") class LearnableScalePlugin(BasePlugin): - def __init__(self, quant_info=None, ignore_layers=None, resume_ckpt_dir=None, **kwargs): + def __init__( + self, + quant_info=None, + ignore_layers=None, + resume_ckpt_dir=None, + from_ptq_ckpt_dir=None, + **kwargs, + ): super().__init__(**kwargs) self.quant_info = quant_info self.ignore_layers = ignore_layers self.resume_ckpt_dir = resume_ckpt_dir + # Optional warm-start from a PTQ "real" checkpoint (only scales are + # read; base weights stay as loaded by from_pretrained). 
Required + # under DeepSpeed ZeRO-3. + self.from_ptq_ckpt_dir = from_ptq_ckpt_dir self.use_weight_quant = self.config.get("use_weight_quant", False) self.use_activation_quant = self.config.get("use_activation_quant", False) self.fp8_attn = self.config.get("fp8_attn", False) @@ -47,9 +64,23 @@ def __init__(self, quant_info=None, ignore_layers=None, resume_ckpt_dir=None, ** self.learn_norm = learnable_cfg.get("norm", False) def before_train(self, **kwargs): + zero3 = is_deepspeed_zero3_enabled() + if zero3 and not self.from_ptq_ckpt_dir: + raise ValueError( + "DeepSpeed ZeRO-3 QAT requires `compression.QAT.from_ptq_ckpt` " + "to warm-start scales (lazy_init via forward is impossible " + "on sharded weights)." + ) + # Retrieve KV head count from model config for per-head quantization model_config = getattr(self.quant_model.model, "config", None) num_kv_heads = getattr(model_config, "num_key_value_heads", -1) + # Pre-allocate ``act_quantizer.scale`` as a Parameter whenever we + # plan to fill it from a checkpoint (full resume OR PTQ warm-start + # OR ZeRO-3 — where lazy_init is impossible). + act_preallocate = ( + self.resume_ckpt_dir is not None or self.from_ptq_ckpt_dir is not None or zero3 + ) for name, module in self.quant_model.model.named_modules(): if isinstance(module, torch.nn.Linear): if any(ig in name for ig in self.ignore_layers): @@ -67,7 +98,7 @@ def before_train(self, **kwargs): self.quant_info, self.use_weight_quant, self.use_activation_quant, - resume=self.resume_ckpt_dir is not None, + resume=act_preallocate, qkv_config=qkv_cfg, ) set_op_by_name(self.quant_model.model, name, q_linear) @@ -78,10 +109,18 @@ def before_train(self, **kwargs): print_info(self.quant_model.model) + # Warm-start scales from a previous PTQ "real" checkpoint. Only + # quantizer Parameters are touched; base Linear weights are NOT + # overwritten. + if self.from_ptq_ckpt_dir is not None: + stream_load_scales(self.quant_model.model, self.from_ptq_ckpt_dir) + if ( self.use_activation_quant and not q_linear.act_quantizer.dynamic and self.resume_ckpt_dir is None + and not zero3 + and self.from_ptq_ckpt_dir is None ): self._lazy_init(**kwargs) @@ -284,5 +323,13 @@ def _get_qkv_config_for_layer(name, quant_config): @torch.no_grad() def quant_inplace(model): for _, module in model.named_modules(): - if isinstance(module, QuantLinear): + if not isinstance(module, QuantLinear): + continue + # Gather the weight together with all weight_quantizer Parameters + # (scale / zero_point / optional LWC clip factors) so the + # fake-quant runs on the full materialised tensor under ZeRO-3. 
+        params = [module.weight]
+        if hasattr(module, "weight_quantizer"):
+            params.extend(module.weight_quantizer.parameters(recurse=True))
+        # ``modifier_rank=0`` writes the in-place update back into the shards
+        # on exit; with ``modifier_rank=None`` DeepSpeed would discard any
+        # modification made while gathered.
+        with gathered_params_if_zero3(params, modifier_rank=0):
-            module.weight.data = module.weight_quantizer(module.weight.data)
+            module.weight.data.copy_(module.weight_quantizer(module.weight.data))
diff --git a/angelslim/compressor/qat/qat.py b/angelslim/compressor/qat/qat.py
index a4274595..c5c524cd 100644
--- a/angelslim/compressor/qat/qat.py
+++ b/angelslim/compressor/qat/qat.py
@@ -17,7 +17,13 @@
 import torch
 from safetensors.torch import save_file
 
-from ...utils import print_info, set_op_by_name
+from ...utils import (
+    gathered_param_if_zero3,
+    model_has_zero3_params,
+    print_info,
+    save_via_model_save_func,
+    set_op_by_name,
+)
 from ..compressor_factory import CompressorFactory
 from ..quant.modules.helper_layer import QDQModule
 from .modules.quantizer import QuantLinear
@@ -38,6 +44,11 @@ def __init__(self, model, slim_config=None):
         self.quant_model.init_ptq(slim_config)
         self.quant_info = self.quant_model.quant_config
         self.plugin_manager = PluginManager()
+        # When set, ``save`` will use this rank-0-only state_dict instead of
+        # walking the model again. Populated by ``convert`` under ZeRO-3 to
+        # avoid keeping a full CPU copy of every layer's QDQModule on every
+        # rank.
+        self._rank0_state_dict = None
         self._init_plugins()
         self._init_trainer()
@@ -57,6 +68,7 @@ def _init_plugins(self):
             quant_info=self.quant_info,
             ignore_layers=self.config["compress_config"].quantization.ignore_layers,
             resume_ckpt_dir=self.config["compress_config"].QAT.resume_ckpt_dir,
+            from_ptq_ckpt_dir=self.config["compress_config"].QAT.from_ptq_ckpt,
             config=self.plugin_config.get("quant_config", {}),
             quant_model=self.quant_model,
         )
@@ -72,46 +84,169 @@ def _init_trainer(self):
     def run(self, dataloader):
         self.trainer.run(dataloader)
 
+    @staticmethod
+    def _gather_clone(tensor):
+        """Detach + CPU-clone a tensor, gathering if it is a ZeRO-3 shard.
+
+        WARNING: every rank gets a full CPU copy. Only safe for SMALL tensors
+        (e.g. scale Parameters). For large weights under ZeRO-3 use
+        ``_sym_gather_clone`` (and drop the copy on rank>0) instead.
+        """
+        if tensor is None:
+            return None
+        with gathered_param_if_zero3(tensor):
+            return tensor.detach().cpu().clone()
+
+    @staticmethod
+    def _sym_gather_clone(tensor):
+        """Symmetric gather-and-clone: every rank gets a full CPU copy.
+
+        Collective timing is symmetric across ranks (minimising NCCL
+        stalls). Caller is responsible for dropping the clone on rank>0
+        immediately to avoid keeping 'world_size' copies alive.
+ """ + if tensor is None: + return None + with gathered_param_if_zero3(tensor): + return tensor.detach().cpu().clone() + def convert(self): if self.save_fmt not in ("real", "real_and_kvcache"): return - print_info("Start QAT convert: replacing QuantLinear with QDQModule...") + zero3 = model_has_zero3_params(self.quant_model.model) + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 quant_algo = self.quant_info.quant_algo + if not zero3: + # ----- single-GPU / non-ZeRO-3 path: original behaviour ----- + print_info("Start QAT convert: replacing QuantLinear with QDQModule...") + for name, module in [ + (n, m) + for n, m in self.quant_model.model.named_modules() + if isinstance(m, QuantLinear) + ]: + weight_scale = ( + module.weight_quantizer.scale.data.clone() + if hasattr(module, "weight_quantizer") + else None + ) + input_scale = None + if module.use_act_quant and hasattr(module, "act_quantizer"): + aq = module.act_quantizer + if getattr(aq, "scale", None) is not None: + input_scale = aq.scale.data.clone() + qdq_module = QDQModule( + quant_algo=quant_algo, + weight=module.weight, + weight_scale=weight_scale, + bias=module.bias, + group_size=( + module.weight_quantizer.group_size + if hasattr(module, "weight_quantizer") + and hasattr(module.weight_quantizer, "group_size") + else 128 + ), + input_scale=input_scale, + ) + set_op_by_name(self.quant_model.model, name, qdq_module) + return + + # ----- ZeRO-3 path: every rank gathers + clones per layer (fast, + # NCCL-symmetric), but only rank0 keeps the data by feeding it into + # ``_rank0_state_dict``. rank>0 drops the clone immediately so peak + # CPU remains bounded by ~one layer's worth of tensors per rank. + # Model structure is NOT modified — we stream straight into the + # state_dict and let ``save_via_model_save_func`` patch + # ``state_dict()`` for the underlying save_func. + print_info( + f"[rank{rank}] Start QAT convert (ZeRO-3 mode: stream rank0 " + "state_dict, keep model structure intact)..." + ) + self._rank0_state_dict = {} if rank == 0 else {} + quant_linear_modules = [ - (name, module) - for name, module in self.quant_model.model.named_modules() - if isinstance(module, QuantLinear) + (n, m) for n, m in self.quant_model.model.named_modules() if isinstance(m, QuantLinear) ] + consumed_prefixes = set() for name, module in quant_linear_modules: + + # Symmetric gather: all ranks clone (memcpy, fast) so NCCL + # timing stays tight. This is ~world_size× transient CPU RAM + # for JUST this one layer; we free right after the rank0 + # branch completes. + weight = self._sym_gather_clone(module.weight) + bias = self._sym_gather_clone(getattr(module, "bias", None)) weight_scale = None if hasattr(module, "weight_quantizer"): - weight_scale = module.weight_quantizer.scale.data.clone() - + weight_scale = self._sym_gather_clone(module.weight_quantizer.scale) input_scale = None if module.use_act_quant and hasattr(module, "act_quantizer"): - act_quantizer = module.act_quantizer - if hasattr(act_quantizer, "scale") and act_quantizer.scale is not None: - input_scale = act_quantizer.scale.data.clone() + aq = module.act_quantizer + if getattr(aq, "scale", None) is not None: + input_scale = self._sym_gather_clone(aq.scale) + consumed_prefixes.add(name) + + if rank != 0: + # Drop the clone immediately; next iteration will overwrite + # these locals anyway but be explicit for clarity. 
+ del weight, bias, weight_scale, input_scale + continue + + # rank0 only: run the fp8/int quantize path via a throwaway + # QDQModule, then move its params into the consolidated dict + # and discard the module. qdq_module = QDQModule( quant_algo=quant_algo, - weight=module.weight, + weight=weight, weight_scale=weight_scale, - bias=module.bias, + bias=bias, group_size=( module.weight_quantizer.group_size - if hasattr(module.weight_quantizer, "group_size") + if hasattr(module, "weight_quantizer") + and hasattr(module.weight_quantizer, "group_size") else 128 ), input_scale=input_scale, ) - set_op_by_name(self.quant_model.model, name, qdq_module) + for sub_name, p in qdq_module.named_parameters(recurse=False): + self._rank0_state_dict[f"{name}.{sub_name}"] = p.detach().cpu() + for sub_name, b in qdq_module.named_buffers(recurse=False): + if sub_name in qdq_module._non_persistent_buffers_set: + continue + self._rank0_state_dict[f"{name}.{sub_name}"] = b.detach().cpu() + del qdq_module, weight, bias, weight_scale, input_scale + + # Second pass: params/buffers that are NOT inside a QuantLinear + # (embeddings, lm_head, layernorms, MoE router gate, ...). The + # collective order MUST be identical across ranks, so this loop + # runs on every rank; only rank0 keeps the data. + for pname, param in self.quant_model.model.named_parameters(): + if any(pname.startswith(p + ".") for p in consumed_prefixes): + continue + with gathered_param_if_zero3(param): + if rank == 0: + self._rank0_state_dict[pname] = param.detach().cpu().clone() + + if rank == 0: + for module_name, mod in self.quant_model.model.named_modules(): + for buf_name, buf in mod.named_buffers(recurse=False): + if buf is None or buf_name in mod._non_persistent_buffers_set: + continue + full_key = f"{module_name}.{buf_name}" if module_name else buf_name + if full_key in self._rank0_state_dict: + continue + self._rank0_state_dict[full_key] = buf.detach().cpu().clone() + print_info( + f"[zero3] convert done: rank0 state_dict has " + f"{len(self._rank0_state_dict)} tensors." 
+ ) def _save_kv_cache_scales(self, save_path: str): """Extract and save KV cache scales to a safetensors file.""" + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 kv_scales = {} for name, module in self.quant_model.model.named_modules(): if not isinstance(module, QuantLinear): @@ -126,8 +261,12 @@ def _save_kv_cache_scales(self, save_path: str): else: continue scale_key = f"{cache_name}.scale" - kv_scales[scale_key] = module.qkv_quantizer.scale.data.clone().float().cpu() + scale_tensor = self._gather_clone(module.qkv_quantizer.scale) + if scale_tensor is not None and rank == 0: + kv_scales[scale_key] = scale_tensor.float() + if rank != 0: + return os.makedirs(save_path, exist_ok=True) out_file = os.path.join(save_path, "kv_cache_scales.safetensors") save_file(kv_scales, out_file) @@ -146,7 +285,12 @@ def save(self, save_path: str): # "real": save real-quant model via model-specific save function elif self.save_fmt == "real": save_func = self.quant_model.get_save_func()(self.quant_model) - save_func.save(os.path.join(save_path, "final_quant_checkpoint")) + save_via_model_save_func( + self.quant_model, + save_func, + os.path.join(save_path, "final_quant_checkpoint"), + prebuilt_state_dict=self._rank0_state_dict, + ) # "save_kvcache_only": only export KV cache scales (kv_cache_scales.safetensors) elif self.save_fmt == "save_kvcache_only": @@ -155,7 +299,12 @@ def save(self, save_path: str): # "real_and_kvcache": save real-quant model AND KV cache scales elif self.save_fmt == "real_and_kvcache": save_func = self.quant_model.get_save_func()(self.quant_model) - save_func.save(os.path.join(save_path, "final_quant_checkpoint")) + save_via_model_save_func( + self.quant_model, + save_func, + os.path.join(save_path, "final_quant_checkpoint"), + prebuilt_state_dict=self._rank0_state_dict, + ) self._save_kv_cache_scales(os.path.join(save_path, "final_quant_checkpoint")) else: diff --git a/angelslim/compressor/qat/trainers/end2end_trainer.py b/angelslim/compressor/qat/trainers/end2end_trainer.py index fb5daf6d..57eac608 100644 --- a/angelslim/compressor/qat/trainers/end2end_trainer.py +++ b/angelslim/compressor/qat/trainers/end2end_trainer.py @@ -13,16 +13,33 @@ # limitations under the License. import torch -import torch.nn.functional as F from datasets import load_dataset from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments from ....data.qat_dataset import QATDataset -from ....utils import print_info +from ....utils import patch_deepspeed_duplicate_check, print_info +from ..plugins.distill_loss import DistillLoss from ..plugins.learnable_scale import set_quant_state from .trainer_factory import TrainerFactory +def _unique_named_params(model, predicate): + """Collect parameters matching ``predicate`` with id-based de-duplication. + + Some QAT setups share a single scale Parameter across multiple + QuantLinear views (e.g. MoE experts built from a shared tensor). HF / + DeepSpeed optimizer init rejects duplicates, so we de-dup by ``id``. 
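+
+    A minimal usage sketch (the same predicate shape ``_init_optimizer``
+    passes in below)::
+
+        scale_params = _unique_named_params(
+            model, lambda n, p: p.requires_grad and "scale" in n
+        )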
+ """ + seen = set() + result = [] + for name, param in model.named_parameters(): + if id(param) in seen or not predicate(name, param): + continue + seen.add(id(param)) + result.append(param) + return result + + class QATSeq2SeqTrainer(Seq2SeqTrainer): def __init__(self, *args, loss_config=None, quant_config=None, **kwargs): super().__init__(*args, **kwargs) @@ -31,93 +48,102 @@ def __init__(self, *args, loss_config=None, quant_config=None, **kwargs): self.loss_type = str(loss_config.get("loss_type", "origin")).lower() self.loss_topk = loss_config.get("loss_topk") self.kd_temperature = float(loss_config.get("kd_temperature", 1.0)) + # ``kd_alpha`` kept for backward compat but IGNORED when + # ``lm_loss_weight`` / ``kd_loss_weight`` are the (new) source of + # truth. self.kd_alpha = float(loss_config.get("kd_alpha", 0.5)) + self.lm_loss_weight = float(loss_config.get("lm_loss_weight", 1.0)) + self.kd_loss_weight = float(loss_config.get("kd_loss_weight", 0.0)) + self.distill_loss = DistillLoss( + loss_type=self.loss_type, + loss_topk=self.loss_topk, + kd_temperature=self.kd_temperature, + ) self.use_weight_quant = quant_config.get("use_weight_quant", False) self.use_activation_quant = quant_config.get("use_activation_quant", False) self.use_qkv_quant = quant_config.get("use_qkv_quant", False) + # Running metric aggregator keyed by logger mode. + from collections import defaultdict + + self._qat_metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + + def _record(self, name, value): + if value is None: + return + mode = "train" if self.model.training else "eval" + v = value.detach().float() if isinstance(value, torch.Tensor) else float(value) + if isinstance(v, torch.Tensor): + self._qat_metrics[mode][name].append(v.item()) + else: + self._qat_metrics[mode][name].append(v) + + # ------------------------------------------------------------------ + # compute_loss + # ------------------------------------------------------------------ + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + labels = inputs.get("labels", None) + lm_on = self.lm_loss_weight > 0.0 + kd_on = self.kd_loss_weight > 0.0 + + # Back-compat: ``loss_type="origin"`` means "pure HF CE loss" → the + # classic SFT path with no distillation. Honour it even when the + # user forgot to set kd_loss_weight=0. if self.loss_type == "origin": - return super().compute_loss( - model, - inputs, - return_outputs=return_outputs, - num_items_in_batch=num_items_in_batch, - ) + kd_on = False + + if not lm_on and not kd_on: + raise ValueError("Both lm_loss_weight and kd_loss_weight are 0 — nothing to optimise.") - teacher_logits = self.get_ori_outputs(model, inputs).logits + # Student forward — always needed. + # HF CausalLM loss is computed when ``labels`` is present in inputs. student_inputs = dict(inputs) - if self.loss_type != "kd": + if not lm_on: + # Still need labels for flat_mask; pop from student kwargs to + # skip HF's internal CE and save some compute. 
student_inputs.pop("labels", None) outputs = model(**student_inputs) - student_logits = outputs.logits - if self.loss_type == "kl": - loss = F.kl_div( - F.log_softmax(student_logits.flatten(0, -2), dim=-1), - F.softmax(teacher_logits.flatten(0, -2), dim=-1), - reduction="batchmean", - ) - elif self.loss_type == "rkl": - loss = F.kl_div( - F.log_softmax(teacher_logits.flatten(0, -2), dim=-1), - F.softmax(student_logits.flatten(0, -2), dim=-1), - reduction="batchmean", + lm_loss = outputs.loss if lm_on and getattr(outputs, "loss", None) is not None else None + if lm_on and lm_loss is None: + raise ValueError( + "lm_loss_weight > 0 but model did not return a loss — " + "check that ``labels`` is set in the batch." ) - elif self.loss_type == "mse": - loss = F.mse_loss(student_logits, teacher_logits) - elif self.loss_type == "kd": - if getattr(outputs, "loss", None) is None: - raise ValueError("loss_type='kd' requires labels to compute CE loss.") - temperature = max(self.kd_temperature, 1e-6) - alpha = self.kd_alpha - distill_loss = F.kl_div( - F.log_softmax(student_logits.flatten(0, -2) / temperature, dim=-1), - F.softmax(teacher_logits.flatten(0, -2) / temperature, dim=-1), - reduction="batchmean", - ) - loss = outputs.loss * (1 - alpha) + distill_loss * (alpha * temperature * temperature) - elif self._is_reverse_topk_loss(self.loss_type): - topk = self._resolve_topk(self.loss_type, student_logits.size(-1)) - top_student_logits, indices = student_logits.topk(topk, dim=-1, sorted=False) - top_teacher_logits = teacher_logits.gather(-1, indices) - loss = F.kl_div( - F.log_softmax(top_teacher_logits.flatten(0, -2), dim=-1), - F.softmax(top_student_logits.flatten(0, -2), dim=-1), - reduction="batchmean", - ) - elif self._is_forward_topk_loss(self.loss_type): - topk = self._resolve_topk(self.loss_type, teacher_logits.size(-1)) - top_teacher_logits, indices = teacher_logits.topk(topk, dim=-1, sorted=False) - top_student_logits = student_logits.gather(-1, indices) - loss = F.kl_div( - F.log_softmax(top_student_logits.flatten(0, -2), dim=-1), - F.softmax(top_teacher_logits.flatten(0, -2), dim=-1), - reduction="batchmean", - ) - else: - raise ValueError(f"Unsupported QAT loss_type: {self.loss_type}") - return (loss, outputs) if return_outputs else loss + kd_info = None + if kd_on: + if labels is None: + raise ValueError("kd_loss_weight > 0 requires ``labels`` in the batch.") + teacher_logits = self.get_ori_outputs(model, inputs).logits + kd_info = self.distill_loss.compute( + outputs.logits, + teacher_logits, + labels, + ) - @staticmethod - def _is_forward_topk_loss(loss_type): - return loss_type.startswith("kl_top") + # Combine. + total = outputs.logits.new_zeros(()) + if lm_loss is not None: + total = total + self.lm_loss_weight * lm_loss + if kd_info is not None: + total = total + self.kd_loss_weight * kd_info["loss"] - @staticmethod - def _is_reverse_topk_loss(loss_type): - return loss_type.startswith("r_kl_top") or loss_type.startswith("rkl_top") + # Logging: record every component whose weight is > 0, plus the + # always-informative forward/backward KL diagnostics when kd is on. + if lm_on and lm_loss is not None: + self._record("lm_loss", lm_loss) + if kd_on and kd_info is not None: + self._record(f"kd/{self.loss_type}", kd_info["loss"]) + # Diagnostic KL(L/R) for any kd variant. Useful to monitor + # teacher-student disagreement independent of the combined + # objective. 
+ self._record("kd/forward_kl", kd_info["forward_kl"]) + self._record("kd/backward_kl", kd_info["backward_kl"]) + self._record("total_loss", total) - def _resolve_topk(self, loss_type, vocab_size): - topk = self.loss_topk - if topk is None and "_top_" in loss_type: - topk = int(loss_type.rsplit("_", 1)[-1]) - if topk is None: - topk = 1000 - topk = int(topk) - if topk <= 0: - raise ValueError(f"loss_topk must be positive, got: {topk}") - return min(topk, vocab_size) + return (total, outputs) if return_outputs else total @torch.no_grad() def get_ori_outputs(self, model, inputs): @@ -136,6 +162,31 @@ def get_ori_outputs(self, model, inputs): ) return outputs + def log(self, logs, start_time=None, *args, **kwargs): + """Inject running QAT loss components (lm_loss / kd/... / kd/forward_kl + / kd/backward_kl / total_loss) into HuggingFace Trainer's log dict. + + Each value is averaged across the steps accumulated since the + previous ``log`` call, then the accumulator is cleared — matching + HF Trainer's behaviour for its built-in ``loss`` key. + """ + mode = "train" if self.model.training else "eval" + bucket = self._qat_metrics.get(mode, {}) + if bucket: + for key, vals in bucket.items(): + if not vals: + continue + avg = sum(vals) / len(vals) + out_key = key if mode == "train" else f"eval_{key}" + logs[out_key] = float(avg) + bucket.clear() + + # Forward to HF Trainer's log. Signature differs across versions. + try: + return super().log(logs, start_time, *args, **kwargs) + except TypeError: + return super().log(logs) + @TrainerFactory.register("end2end") class End2EndTrainer: @@ -157,6 +208,8 @@ def __init__(self, quant_model, config, plugin_manager): "loss_topk": config["compress_config"].QAT.loss_topk, "kd_temperature": config["compress_config"].QAT.kd_temperature, "kd_alpha": config["compress_config"].QAT.kd_alpha, + "lm_loss_weight": config["compress_config"].QAT.lm_loss_weight, + "kd_loss_weight": config["compress_config"].QAT.kd_loss_weight, } self.quant_config = { "use_weight_quant": config["compress_config"] @@ -173,13 +226,13 @@ def __init__(self, quant_model, config, plugin_manager): def _init_optimizer(self): lr = float(self.config["compress_config"].QAT.hf_args.get("learning_rate", 1e-5)) wd = float(self.config["compress_config"].QAT.hf_args.get("weight_decay", 0)) + scale_params = _unique_named_params( + self.quant_model.model, + lambda n, p: p.requires_grad and ("scale" in n or "zero_point" in n), + ) params = [ { - "params": [ - p - for n, p in self.quant_model.model.named_parameters() - if "scale" in n or "zero_point" in n - ], + "params": scale_params, "weight_decay": wd, "lr": lr, } @@ -191,6 +244,7 @@ def _init_optimizer(self): .get("lwc", {}) .get("enable_lwc", False) ) + lwc_param_count = 0 if enable_lwc: lwc_lr = float( self.config["compress_config"] @@ -198,30 +252,40 @@ def _init_optimizer(self): .get("lwc", {}) .get("lwc_lr", 1e-1) ) - lwc_params = [ - { - "params": [ - p - for n, p in self.quant_model.model.named_parameters() - if "clip_factor_w_max" in n or "clip_factor_w_min" in n - ], - "weight_decay": wd, - "lr": lwc_lr, - } - ] - params.extend(lwc_params) + lwc_params = _unique_named_params( + self.quant_model.model, + lambda n, p: p.requires_grad + and ("clip_factor_w_max" in n or "clip_factor_w_min" in n), + ) + lwc_param_count = len(lwc_params) + params.append({"params": lwc_params, "weight_decay": wd, "lr": lwc_lr}) + + if not any(group["params"] for group in params): + raise ValueError("QAT optimizer has no trainable parameters.") self.optimizer = 
torch.optim.AdamW(params) if enable_lwc: - print_info(f"Init optimizer with learnable lr={lr} lwc_lr={lwc_lr} weight_decay={wd}") + print_info( + f"Init optimizer with {len(scale_params)} scale params, " + f"{lwc_param_count} lwc params, lr={lr} lwc_lr={lwc_lr} weight_decay={wd}" + ) else: - print_info(f"Init optimizer with learnable lr={lr} weight_decay={wd}") + print_info( + f"Init optimizer with {len(scale_params)} scale params, " + f"lr={lr} weight_decay={wd}" + ) def prepare_trainer(self): if self.training_mode == "blockwise": return if self.training_mode == "end2end" and self.dist_mode == "hf": self._init_optimizer() + # When DeepSpeed is used, neutralize its duplicate-parameter + # check: it rejects param-groups that share tensors, which our + # scale/zero_point setup can legally have (shared tensors + # across views). Idempotent and a no-op if deepspeed is absent. + if self.config["compress_config"].QAT.hf_args.get("deepspeed") is not None: + patch_deepspeed_duplicate_check() self.external_trainer = QATSeq2SeqTrainer( model=self.quant_model.model, processing_class=self.quant_model.tokenizer, diff --git a/angelslim/data/dataloader.py b/angelslim/data/dataloader.py index a3dcc0a6..72488ce1 100644 --- a/angelslim/data/dataloader.py +++ b/angelslim/data/dataloader.py @@ -44,6 +44,7 @@ def create_data_loader( use_audio_in_video: bool = False, model_name: str = None, quantization_config: str = None, + is_sft_data: bool = False, ) -> DataLoader: """ Create appropriate DataLoader based on data source @@ -85,6 +86,7 @@ def create_data_loader( max_length=max_length, data_path=data_source, num_samples=num_samples, + is_sft_data=is_sft_data, ) elif data_type == "MultiModalDataset": dataset = MultiModalDataset( diff --git a/angelslim/data/text_dataset.py b/angelslim/data/text_dataset.py index 510b3a05..96e5a874 100644 --- a/angelslim/data/text_dataset.py +++ b/angelslim/data/text_dataset.py @@ -34,8 +34,10 @@ def __init__( device: str = "cpu", max_length: int = 4096, num_samples: int = -1, + is_sft_data: bool = False, ): super().__init__(processor, device, max_length) + self.is_sft_data = is_sft_data self._load_data(data_path, num_samples) def _load_data(self, data_path: str, num_samples: int): @@ -74,9 +76,8 @@ def _load_hf_dataset(self, data_path: str, num_samples: int, block_size: int = 2 inputs = { "input_ids": torch.tensor(result["input_ids"][i]).unsqueeze(0).to(self.device) } - labels = inputs["input_ids"].roll(shifts=-1, dims=-1) - labels[:, -1] = -100 - inputs["labels"] = labels.to(self.device) + # HF CausalLM models shift labels internally; feed labels == input_ids. + inputs["labels"] = inputs["input_ids"].clone() inputs["attention_mask"] = torch.tensor(result["attention_mask"][i]).to(self.device) self.data.append(inputs) @@ -101,8 +102,8 @@ def _load_parquet_data(self, data_path: str, num_samples: int): if "labels" in df.columns: labels = torch.tensor(df["labels"].iloc[i]).unsqueeze(0) else: - labels = model_inputs["input_ids"].roll(shifts=-1, dims=-1) - labels[:, -1] = -100 + # HF CausalLM models shift labels internally; feed labels == input_ids. 
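+                # (Internally HF scores logits[:, :-1, :] against
+                # labels[:, 1:], so feeding labels == input_ids trains
+                # next-token prediction without manual roll / -100 edits.)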
+                labels = model_inputs["input_ids"].clone()
 
             data_item = {
                 "input_ids": model_inputs["input_ids"].to(self.device),
@@ -128,42 +129,83 @@ def _load_jsonl_data(self, data_path: str, num_samples: int):
             # Prepare messages
             messages = self._prepare_messages(data)
 
-            # Apply chat template
-            text = self.processor.apply_chat_template(
-                messages, tokenize=False, add_generation_prompt=True
+            # Find the LAST assistant turn — loss is computed ONLY on
+            # this reply. Everything before it (system + user(s) +
+            # earlier assistant(s)) serves as prompt context.
+            last_assistant_idx = None
+            for idx, item in enumerate(messages):
+                if item["role"] == "assistant":
+                    last_assistant_idx = idx
+            if last_assistant_idx is None:
+                # No assistant turn -> nothing to supervise; skip.
+                continue
+            prompt_messages = messages[:last_assistant_idx]
+            assistant_msg = messages[last_assistant_idx]
+
+            # Tokenize the prompt (up to the generation marker) and the
+            # full conversation separately so we know exactly where the
+            # assistant reply starts.
+            prompt_text = self.processor.apply_chat_template(
+                prompt_messages, tokenize=False, add_generation_prompt=True
+            )
+            full_messages = prompt_messages + [assistant_msg]
+            full_text = self.processor.apply_chat_template(
+                full_messages, tokenize=False, add_generation_prompt=False
             )
 
-            thinking_data = False
-            for dic in messages:
-                if dic["role"] == "assistant":
-                    if "<think>" and "</think>" in dic["content"]:
-                        thinking_data = True
-                        break
+            # Legacy branch: thinking-style data without a chat template.
+            thinking_data = any(
+                m["role"] == "assistant"
+                and "<think>" in m.get("content", "")
+                and "</think>" in m.get("content", "")
+                for m in messages
+            )
             if thinking_data:
-                text = self.processor.bos_token if self.processor.bos_token is not None else ""
-                for dic in messages:
-                    if dic["role"] == "system":
-                        text += dic["content"]
-                    elif dic["role"] == "user":
-                        text = text + "<|User|>" + dic["content"] + "<|Assistant|>"
-                    elif dic["role"] == "assistant":
-                        text = text + dic["content"] + self.processor.eos_token
+                bos = self.processor.bos_token or ""
+                prompt_text = bos
+                for m in prompt_messages:
+                    if m["role"] == "system":
+                        prompt_text += m["content"]
+                    elif m["role"] == "user":
+                        prompt_text += "<|User|>" + m["content"] + "<|Assistant|>"
+                    elif m["role"] == "assistant":
+                        prompt_text += m["content"] + self.processor.eos_token
+                full_text = prompt_text + assistant_msg["content"] + self.processor.eos_token
+
+            # Token-level prompt length: count tokens in ``prompt_text``
+            # without special-token insertion so it aligns with the
+            # prefix of the tokenization of ``full_text``.
+            prompt_ids = self.processor(
+                prompt_text,
+                add_special_tokens=False,
+                return_tensors=None,
+            )["input_ids"]
+            prompt_len = len(prompt_ids)
 
             model_inputs = self.processor(
-                text=[text],
+                text=[full_text],
                 return_tensors="pt",
                 max_length=self.max_length,
                 truncation=True,
                 padding="max_length",
            )
-            labels = model_inputs["input_ids"].roll(shifts=-1, dims=-1)
-            labels[:, -1] = -100
+            # Build labels: HF CausalLM shifts internally (logits[:, :-1]
+            # are scored against labels[:, 1:]), so ``labels == input_ids``
+            # trains next-token prediction. Prompt positions are set to
+            # -100 so they contribute no loss.
+            input_ids = model_inputs["input_ids"]
+            attention_mask = model_inputs["attention_mask"]
+            labels = input_ids.clone()
+            if self.is_sft_data:
+                labels[:, :prompt_len] = -100
+            # Also mask padding tokens.
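+            # (padding="max_length" pads every sample up to self.max_length,
+            # so unmasked pad positions would otherwise dominate the CE loss.)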
+ labels[attention_mask == 0] = -100 self.data.append( { - "input_ids": model_inputs["input_ids"].to(self.device), - "attention_mask": model_inputs["attention_mask"].to(self.device), + "input_ids": input_ids.to(self.device), + "attention_mask": attention_mask.to(self.device), "labels": labels.to(self.device), } ) diff --git a/angelslim/engine.py b/angelslim/engine.py index 757d8741..8c5344c2 100644 --- a/angelslim/engine.py +++ b/angelslim/engine.py @@ -101,6 +101,13 @@ def prepare_model( assert model_name, "model_name must be specified." assert model_path, "model_path must be specified." + # Normalize device_map for DeepSpeed ZeRO / distributed training: YAML + # configs often write ``None`` / ``"None"`` / ``"distributed"`` to + # mean "no pre-placement, let DeepSpeed shard". HF only accepts + # Python ``None`` there. + if isinstance(device_map, str) and device_map.lower() in ("none", "distributed"): + device_map = None + # Initialize slim model by ModelFactory self.slim_model = SlimModelFactory.create( model_name, model=model, deploy_backend=deploy_backend @@ -152,6 +159,7 @@ def prepare_data( use_audio_in_video=False, model_name=None, quantization_config=None, + is_sft_data=False, ) -> Optional[Any]: """Prepare compression dataset""" if custom_dataloader is not None: @@ -178,6 +186,7 @@ def prepare_data( use_audio_in_video=use_audio_in_video, model_name=model_name, quantization_config=quantization_config, + is_sft_data=is_sft_data, ) self.max_seq_length = max_length diff --git a/angelslim/models/base_model.py b/angelslim/models/base_model.py index 3f15befb..39ee6ece 100644 --- a/angelslim/models/base_model.py +++ b/angelslim/models/base_model.py @@ -24,7 +24,13 @@ from ..compressor.quant.core import QuantConfig from ..compressor.quant.modules import NVFP4QDQModule, QDQModule -from ..utils import common_prefix, print_info +from ..utils import ( + common_prefix, + is_deepspeed_zero3_enabled, + print_info, + stream_load_weights, + zero3_empty_model_from_pretrained, +) __all__ = ["BaseLLMModel"] @@ -65,6 +71,27 @@ def from_pretrained( using_multi_nodes=False, attn_implementation="default", ): + # DeepSpeed ZeRO-3 path: build an empty sharded model on every rank + # via deepspeed.zero.Init, linearize fused MoE experts, then stream + # the safetensors checkpoint into the (sharded) parameters. This + # avoids HF's path that materialises the full state_dict on every + # rank's CPU before sharding. 
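+        # NOTE: ``is_deepspeed_zero3_enabled`` only reports True once HF's
+        # ``TrainingArguments`` has registered the DeepSpeed config, so the
+        # training arguments must be constructed before this call
+        # (assumption about call order, matching the QAT trainer flow).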
+ if is_deepspeed_zero3_enabled(): + log_prefix = f"[{type(self).__name__}.from_pretrained]" + self.model = zero3_empty_model_from_pretrained( + model_path, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + use_cache=use_cache, + attn_implementation=attn_implementation, + log_prefix=log_prefix, + ) + stream_load_weights(self.model, model_path, log_prefix=log_prefix) + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=trust_remote_code + ) + return + kwargs = dict( torch_dtype=torch_dtype, device_map=device_map, diff --git a/angelslim/models/llm/hunyuan_v3_moe.py b/angelslim/models/llm/hunyuan_v3_moe.py index d8919643..fdfd9c5c 100644 --- a/angelslim/models/llm/hunyuan_v3_moe.py +++ b/angelslim/models/llm/hunyuan_v3_moe.py @@ -16,20 +16,50 @@ import torch import torch.nn as nn -from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.models.hy_v3.modeling_hy_v3 import ( ALL_ATTENTION_FUNCTIONS, HYV3Experts, + HYV3TopKRouter, apply_rotary_pos_emb, eager_attention_forward, ) from ...compressor.quant.core import PTQSaveVllmHF +from ...utils import is_deepspeed_zero3_enabled from ...utils.utils import find_layers, find_parent_layer_and_sub_name from ..base_model import BaseLLMModel from ..model_factory import SlimModelFactory +def _patch_hyv3_router_for_zero3(): + if getattr(HYV3TopKRouter, "_angelslim_zero3_dtype_patch", False): + return + + def patched_forward( + self, + hidden_states: torch.Tensor, + e_score_correction_bias: torch.Tensor, + ) -> tuple: + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = nn.functional.linear( + hidden_states.to(self.weight.dtype), + self.weight, + ).to(torch.float32) + routing_weights = torch.sigmoid(router_logits) + + scores_for_choice = routing_weights + e_score_correction_bias + _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False) + top_k_weights = routing_weights.gather(1, top_k_index) + + top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-20) + top_k_weights = top_k_weights * self.router_scaling_factor + + return router_logits, top_k_weights, top_k_index + + HYV3TopKRouter.forward = patched_forward + HYV3TopKRouter._angelslim_zero3_dtype_patch = True + + class HYV3ExpertsWithLinear(HYV3Experts): """Wrapper around HYV3Experts that exposes per-expert weights as nn.Linear modules. 
@@ -138,19 +168,18 @@ def from_pretrained(
     ):
         attn_implementation = "eager"
         torch_dtype = torch.bfloat16
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            attn_implementation=attn_implementation,
+        if is_deepspeed_zero3_enabled():
+            _patch_hyv3_router_for_zero3()
+
+        super().from_pretrained(
+            model_path=model_path,
             torch_dtype=torch_dtype,
             device_map=device_map,
             trust_remote_code=trust_remote_code,
             low_cpu_mem_usage=low_cpu_mem_usage,
             use_cache=use_cache,
-        )
-
-        # Load tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=trust_remote_code
+            using_multi_nodes=using_multi_nodes,
+            attn_implementation=attn_implementation,
         )
 
     def replace_moe(self):
@@ -179,9 +208,9 @@ def get_observer_layers(self):
             "mlp.gate_proj",
             "mlp.up_proj",
             "mlp.down_proj",
-            "shared_mlp.gate_proj",
-            "shared_mlp.up_proj",
-            "shared_mlp.down_proj",
+            "shared_experts.gate_proj",
+            "shared_experts.up_proj",
+            "shared_experts.down_proj",
         ]
         expert_pattern = [
             r"model\.layers\.\d+\.mlp\.experts\.\d+\.gate_proj",
@@ -326,11 +355,9 @@ def patched_forward(
                 key_states, value_states, attn_module.layer_idx, cache_kwargs
             )
 
-        attention_interface = eager_attention_forward
-        if attn_module.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[
-                attn_module.config._attn_implementation
-            ]
+        attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
+            attn_module.config._attn_implementation, eager_attention_forward
+        )
 
         attn_output, attn_weights = attention_interface(
             attn_module,
diff --git a/angelslim/utils/__init__.py b/angelslim/utils/__init__.py
index d80be244..a12f24b5 100644
--- a/angelslim/utils/__init__.py
+++ b/angelslim/utils/__init__.py
@@ -30,3 +30,4 @@
 from .utils import print_with_rank  # noqa: F401
 from .utils import rank0_print  # noqa: F401
 from .utils import set_op_by_name  # noqa: F401
+from .zero3_io import *  # noqa: F401 F403
diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py
index 77481b9e..31b5acdc 100644
--- a/angelslim/utils/config_parser.py
+++ b/angelslim/utils/config_parser.py
@@ -185,6 +185,7 @@ class DatasetConfig:
     batch_size: int = field(default=1)
     shuffle: bool = field(default=False)
     inference_settings: Optional[Dict[str, Any]] = field(default=None)
+    is_sft_data: bool = field(default=False)
 
 
 @dataclass
@@ -272,10 +273,23 @@ class QATTrainingConfig:
     hf_dataset: Optional[str] = None
     do_train: bool = field(default=True)
     resume_ckpt_dir: Optional[str] = None
+    # Optional warm-start directory (a previous ``save_fmt="real"`` output).
+    # When set, scales / zero_points / kv-cache scales are loaded from this
+    # directory into the freshly-created ``QuantLinear`` quantizers; base
+    # ``Linear`` weights still come from ``model.model_path``. Required when
+    # DeepSpeed ZeRO-3 is enabled (no calibration is possible on shards).
+    from_ptq_ckpt: Optional[str] = None
     loss_type: str = field(default="origin")
     loss_topk: Optional[int] = None
     kd_temperature: float = field(default=1.0)
     kd_alpha: float = field(default=0.5)
+    # ---- new loss-weight controls (compose LM + KD loss) ----
+    # lm_loss_weight: weight on the HF CausalLM CE loss (labels must be set).
+    # kd_loss_weight: weight on the chosen distillation loss (kl / rkl /
+    # mse / kd / cakld / kl_top_K / r_kl_top_K). Only loss components
+    # with a strictly-positive weight are computed AND logged.
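+    # Example: ``lm_loss_weight: 0.5`` with ``kd_loss_weight: 0.5`` optimises
+    # total = 0.5 * CE + 0.5 * KD (an equal mix of SFT and distillation).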
+ lm_loss_weight: float = field(default=1.0) + kd_loss_weight: float = field(default=0.0) hf_args: Dict[str, Any] = field(default_factory=dict) diff --git a/angelslim/utils/utils.py b/angelslim/utils/utils.py index 36422758..e10f2e98 100644 --- a/angelslim/utils/utils.py +++ b/angelslim/utils/utils.py @@ -49,10 +49,18 @@ def set_op_by_name(layer, name, new_module): if len(levels) > 1: mod_ = layer for l_idx in range(len(levels) - 1): - if levels[l_idx].isdigit(): - mod_ = mod_[int(levels[l_idx])] + part = levels[l_idx] + if part.isdigit(): + # Prefer integer indexing for nn.ModuleList / nn.Sequential; + # fall back to getattr for custom containers (e.g. our + # LinearizedMoeExperts that registers experts via + # ``setattr(self, str(idx), ...)``). + try: + mod_ = mod_[int(part)] + except (TypeError, IndexError, KeyError): + mod_ = getattr(mod_, part) else: - mod_ = getattr(mod_, levels[l_idx]) + mod_ = getattr(mod_, part) setattr(mod_, levels[-1], new_module) else: setattr(layer, name, new_module) diff --git a/angelslim/utils/zero3_io.py b/angelslim/utils/zero3_io.py new file mode 100644 index 00000000..614ee5bd --- /dev/null +++ b/angelslim/utils/zero3_io.py @@ -0,0 +1,853 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""All-in-one DeepSpeed ZeRO-3 helpers for AngelSlim QAT. + +This module concentrates everything ZeRO-3-specific so that the rest of the +codebase touches only a few thin call-sites: + + * ZeRO-3 detection / parameter gathering helpers. + * Empty model construction under ``deepspeed.zero.Init`` plus a generic + "linearize fused MoE experts" pass that builds *empty* per-expert + ``nn.Linear`` modules (no copy from the old fused tensor required — + weights are filled later from the safetensors checkpoint). + * Streaming weight loader: walks the safetensors shards once, on rank 0 + only, and broadcasts each tensor into the (possibly sharded) target + parameter via ``GatheredParameters(modifier_rank=0)``. Handles fused + MoE keys (``...experts.gate_up_proj``, ``...experts.down_proj``) by + slicing per expert into the linearized targets. + * Streaming scale loader for QAT warm-start: reads ``*.weight_scale``, + ``*.input_scale``, ``*.k_cache.scale``, ``*.v_cache.scale`` keys from a + PTQ "real" checkpoint and writes them into the freshly-created + ``QuantLinear`` quantizer parameters. + * Saving: gather a sharded model into a rank-0 CPU state_dict and call + the model-specific save_func by patching ``state_dict``. + * Optimizer-side patches needed because QAT scale parameters are tied + across multiple modules in a layer. + +By design, **nothing in this file mutates the model when ZeRO-3 is not +enabled** (each helper is a no-op or behaves identically to the +non-distributed path). Importing this module is therefore safe in any +configuration. 
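+
+A minimal call-order sketch under ZeRO-3 (the concrete call-sites live in
+``BaseLLMModel.from_pretrained`` and the QAT plugins)::
+
+    model = zero3_empty_model_from_pretrained(model_path, torch_dtype="auto")
+    stream_load_weights(model, model_path)       # fill the sharded params
+    stream_load_scales(model, ptq_ckpt_dir)      # optional QAT warm-start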
+""" + +from __future__ import annotations + +import gc +import glob +import json +import os +from contextlib import contextmanager, nullcontext +from typing import Optional + +import torch +import torch.nn as nn + +from .lazy_imports import deepspeed +from .utils import find_parent_layer_and_sub_name, print_info + +ZERO3_PARAM_ATTRS = ("ds_id", "ds_status", "ds_numel", "ds_tensor") + + +# --------------------------------------------------------------------------- +# Basic detection / context helpers +# --------------------------------------------------------------------------- + + +def is_deepspeed_zero3_enabled() -> bool: + """True iff HuggingFace's ``HfTrainerDeepSpeedConfig`` is registered with + ZeRO stage 3. Returns False if ``transformers``/``deepspeed`` is not + importable.""" + try: + from transformers.integrations.deepspeed import ( + is_deepspeed_zero3_enabled as _hf, + ) + + return bool(_hf()) + except Exception: # noqa: BLE001 + return False + + +def is_zero3_param(x) -> bool: + """True iff ``x`` is a ZeRO-3 sharded parameter (``deepspeed.zero.Init`` + has injected its bookkeeping attributes).""" + if not isinstance(x, torch.nn.Parameter): + return False + return any(hasattr(x, attr) for attr in ZERO3_PARAM_ATTRS) + + +@contextmanager +def gathered_param_if_zero3(x, modifier_rank: Optional[int] = None): + """All-gather a ZeRO-3 shard for the lifetime of the block. + + Pure no-op (yields ``x`` unchanged) when ``x`` is not a ZeRO-3 shard, so + callers can always wrap their critical sections without branching. + """ + if is_zero3_param(x): + ctx = deepspeed.zero.GatheredParameters([x], modifier_rank=modifier_rank) + else: + ctx = nullcontext() + with ctx: + yield x + + +@contextmanager +def gathered_params_if_zero3(params, modifier_rank: Optional[int] = None): + """Batched variant of :func:`gathered_param_if_zero3`.""" + params = [p for p in params if p is not None] + z3 = [p for p in params if is_zero3_param(p)] + if z3: + ctx = deepspeed.zero.GatheredParameters(z3, modifier_rank=modifier_rank) + else: + ctx = nullcontext() + with ctx: + yield params + + +def model_has_zero3_params(model) -> bool: + return any(is_zero3_param(p) for p in model.parameters()) + + +def _rank() -> int: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_rank() + return 0 + + +def _cleanup() -> None: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Empty MoE linearization +# --------------------------------------------------------------------------- + + +class LinearizedMoeExperts(nn.Module): + """Empty per-expert ``nn.Linear`` container for fused MoE experts. + + This module mirrors the ``forward`` of HuggingFace's fused experts but + holds parameters as ``num_experts`` triplets of ``(gate_proj, up_proj, + down_proj)`` ``nn.Linear`` modules. Construction is **purely structural** + — no weight copy from the old fused tensor — so it is safe to instantiate + under ``deepspeed.zero.Init`` and let the streaming loader fill in the + weights afterwards. 
+ """ + + _angelslim_linearized_moe = True + + def __init__( + self, + num_experts: int, + hidden_dim: int, + intermediate_dim: int, + act_fn, + dtype=torch.bfloat16, + device=None, + config=None, + ): + super().__init__() + self.num_experts = int(num_experts) + self.hidden_dim = int(hidden_dim) + self.intermediate_dim = int(intermediate_dim) + self.act_fn = act_fn + if config is not None: + self.config = config + + if device is None or (isinstance(device, torch.device) and device.type == "meta"): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + for expert_idx in range(self.num_experts): + expert = nn.ModuleDict( + { + "gate_proj": nn.Linear( + self.hidden_dim, + self.intermediate_dim, + bias=False, + dtype=dtype, + device=device, + ), + "up_proj": nn.Linear( + self.hidden_dim, + self.intermediate_dim, + bias=False, + dtype=dtype, + device=device, + ), + "down_proj": nn.Linear( + self.intermediate_dim, + self.hidden_dim, + bias=False, + dtype=dtype, + device=device, + ), + } + ) + setattr(self, str(expert_idx), expert) + + def __getitem__(self, idx): + return getattr(self, str(idx)) + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit_mask = torch.greater(expert_mask.sum(dim=(-1, -2)), 0) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + # ZeRO-3 gathers parameters when a Linear is entered. All ranks must + # enter the same experts in the same order even when their local + # batches route to different experts, otherwise collectives deadlock. + expert_hit_int = expert_hit_mask.to(torch.int32) + torch.distributed.all_reduce( + expert_hit_int, + op=torch.distributed.ReduceOp.MAX, + ) + expert_hit_mask = expert_hit_int.to(torch.bool) + expert_hit = expert_hit_mask.nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + expert_layer = getattr(self, str(int(expert_idx.item()))) + gate = expert_layer["gate_proj"](current_state) + up = expert_layer["up_proj"](current_state) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = expert_layer["down_proj"](current_hidden_states) + current_hidden_states = ( + current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + ) + final_hidden_states.index_add_( + 0, token_idx, current_hidden_states.to(final_hidden_states.dtype) + ) + + return final_hidden_states + + +def _is_fused_moe_experts(module) -> bool: + """Heuristic: matches HF Qwen3MoeExperts / HYV3Experts / similar. 
+ + They all expose ``gate_up_proj`` / ``down_proj`` as ``nn.Parameter`` + plus ``num_experts``, ``hidden_dim``, ``intermediate_dim``, ``act_fn``.""" + if isinstance(module, LinearizedMoeExperts): + return False + required = ( + "gate_up_proj", + "down_proj", + "num_experts", + "hidden_dim", + "intermediate_dim", + "act_fn", + ) + return all(hasattr(module, a) for a in required) + + +def _ds_full_shape(param): + """Full shape of a parameter, accounting for ZeRO-3 sharding.""" + shape = getattr(param, "ds_shape", None) + if shape is None: + shape = param.shape + return tuple(int(x) for x in shape) + + +def linearize_moe_experts_empty(model, dtype=torch.bfloat16) -> int: + """Replace every fused MoE experts module in ``model`` with an empty + :class:`LinearizedMoeExperts`. + + Under ZeRO-3 the new ``nn.Linear`` parameters are created inside a + ``deepspeed.zero.Init`` context so they get partitioned immediately. + Weights are NOT copied from the old fused tensors — the streaming + safetensors loader is responsible for that. The old module is dropped + afterwards. + """ + targets = [] + for name, module in model.named_modules(): + if _is_fused_moe_experts(module): + targets.append(name) + + if not targets: + return 0 + + z3 = is_deepspeed_zero3_enabled() + replaced = 0 + for name in targets: + parent, sub_name = find_parent_layer_and_sub_name(model, name) + old = getattr(parent, sub_name) + + # Resolve dimensions from the (possibly sharded) old tensors. + gate_up_shape = _ds_full_shape(old.gate_up_proj) + down_shape = _ds_full_shape(old.down_proj) + num_experts = int(old.num_experts) + hidden_dim = int(old.hidden_dim) + # gate_up_proj: [num_experts, 2*intermediate_dim, hidden_dim] + if len(gate_up_shape) >= 3 and gate_up_shape[-2] % 2 == 0: + intermediate_dim = gate_up_shape[-2] // 2 + elif len(down_shape) >= 3: + intermediate_dim = int(down_shape[-1]) + else: + intermediate_dim = int(old.intermediate_dim) + + old_dtype = old.gate_up_proj.dtype if hasattr(old.gate_up_proj, "dtype") else dtype + config = getattr(old, "config", None) + act_fn = old.act_fn + + ctx = deepspeed.zero.Init() if z3 else nullcontext() + with ctx: + new_module = LinearizedMoeExperts( + num_experts=num_experts, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + act_fn=act_fn, + dtype=old_dtype, + device=None, # let LinearizedMoeExperts pick cuda + config=config, + ) + + setattr(parent, sub_name, new_module) + del old + replaced += 1 + _cleanup() + + print_info(f"[zero3] linearize_moe_experts_empty: replaced {replaced} fused module(s).") + return replaced + + +# --------------------------------------------------------------------------- +# Empty model construction + streaming weight loader +# --------------------------------------------------------------------------- + + +def _resolve_dtype(torch_dtype, config): + if isinstance(torch_dtype, torch.dtype): + return torch_dtype + if isinstance(torch_dtype, str) and torch_dtype != "auto": + return getattr(torch, torch_dtype) + resolved = getattr(config, "torch_dtype", None) or torch.float32 + if isinstance(resolved, str): + return getattr(torch, resolved) + return resolved + + +def zero3_empty_model_from_pretrained( + model_path, + torch_dtype="auto", + trust_remote_code=True, + use_cache=False, + attn_implementation="default", + log_prefix="[zero3]", +): + """Build an EMPTY ZeRO-3 sharded model from a HuggingFace ``model_path``. + + Linearizes all fused MoE experts immediately so subsequent QuantLinear + insertion can iterate flat ``nn.Linear`` modules. 
Does NOT load + weights — caller must invoke :func:`stream_load_weights`. + """ + from transformers import AutoConfig, AutoModelForCausalLM + from transformers.integrations.deepspeed import HfDeepSpeedConfig # noqa: F401 + + # ``no_init_weights`` / ``no_tie_weights`` moved across transformers + # versions: newest (>=5.x) expose them under ``transformers.initialization``; + # older releases kept them in ``modeling_utils`` or the top-level package. + no_init_weights = None + no_tie_weights = None + for mod in ("transformers.initialization", "transformers.modeling_utils", "transformers"): + try: + m = __import__(mod, fromlist=["no_init_weights"]) + if hasattr(m, "no_init_weights"): + no_init_weights = m.no_init_weights + if hasattr(m, "no_tie_weights"): + no_tie_weights = m.no_tie_weights + except Exception: # noqa: BLE001 + continue + if no_init_weights is not None and no_tie_weights is not None: + break + if no_init_weights is None: + no_init_weights = nullcontext # type: ignore + if no_tie_weights is None: + no_tie_weights = nullcontext # type: ignore + + config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) + if attn_implementation != "default": + config._attn_implementation = attn_implementation + if use_cache is not None: + config.use_cache = use_cache + + resolved = _resolve_dtype(torch_dtype, config) + print_info( + f"{log_prefix} build empty ZeRO-3 model dtype={resolved} from config " + f"{getattr(config, 'model_type', '?')}" + ) + + # ``from_config`` triggers ``deepspeed.zero.Init`` automatically when + # ``HfTrainerDeepSpeedConfig`` is registered (i.e. when + # is_deepspeed_zero3_enabled() returns True). + with no_init_weights(), no_tie_weights(): + model = AutoModelForCausalLM.from_config( + config, + torch_dtype=resolved, + trust_remote_code=trust_remote_code, + ) + + # Linearize fused MoE experts BEFORE weight loading so the loader can + # write per-expert slices directly into the new Linear targets. + linearize_moe_experts_empty(model, dtype=resolved) + + return model + + +def _shards(model_path): + """Yield ``(shard_path, [keys])`` for every safetensors shard.""" + from safetensors import safe_open + + index_path = os.path.join(model_path, "model.safetensors.index.json") + if os.path.isfile(index_path): + with open(index_path, "r") as f: + weight_map = json.load(f)["weight_map"] + per_shard = {} + for key, shard in weight_map.items(): + per_shard.setdefault(shard, []).append(key) + for shard in sorted(per_shard): + yield os.path.join(model_path, shard), per_shard[shard] + return + + paths = sorted(glob.glob(os.path.join(model_path, "*.safetensors"))) + if not paths: + raise FileNotFoundError(f"No safetensors found under {model_path}") + for shard_path in paths: + with safe_open(shard_path, framework="pt") as r: + yield shard_path, list(r.keys()) + + +def _broadcast_into_target(src, target, *, is_buffer=False, key=None): + """Copy ``src`` (rank0 only, or None on other ranks) into ``target``. + + Handles three cases: + * ZeRO-3 sharded ``Parameter``: gather, rank0 writes, exit gather. + * Regular distributed ``Parameter`` / replicated buffer: rank0 stages, + then broadcast. + * Single-process: direct copy. 
+ """ + dist_active = torch.distributed.is_available() and torch.distributed.is_initialized() + + if is_zero3_param(target): + with gathered_param_if_zero3(target, modifier_rank=0): + if _rank() == 0: + if src is None or src.shape != target.shape: + return False + target.data.copy_(src.to(device=target.device, dtype=target.dtype)) + return True + + # Regular tensor (parameter or buffer). + if dist_active: + if _rank() == 0: + if src is None or (not is_buffer and src.shape != target.shape): + return False + tmp = src.to(device=target.device, dtype=target.dtype).contiguous() + else: + tmp = torch.empty_like(target) + torch.distributed.broadcast(tmp, src=0) + target.data.copy_(tmp) + return True + + if src is None: + return False + target.data.copy_(src.to(device=target.device, dtype=target.dtype)) + return True + + +def stream_load_weights(model, model_path, log_prefix="[zero3]"): + """Stream a HF safetensors checkpoint into ``model``. + + Recognises fused MoE keys ``*.experts.gate_up_proj`` and + ``*.experts.down_proj`` and dispatches the per-expert slices into the + matching :class:`LinearizedMoeExperts` children. All other keys are + matched against ``model.named_parameters()`` / ``named_buffers()``. + + rank0 reads the bytes; ZeRO-3 sharded targets are filled inside + ``GatheredParameters(modifier_rank=0)``; replicated tensors are + broadcast. + """ + from safetensors import safe_open + + name_to_param = dict(model.named_parameters()) + name_to_buffer = dict(model.named_buffers()) + rank = _rank() + + loaded = 0 + skipped = 0 + seen_targets = set() + + for shard_path, keys in _shards(model_path): + with safe_open(shard_path, framework="pt") as reader: + for key in keys: + if key.endswith(".experts.gate_up_proj"): + base = key[: -len(".gate_up_proj")] + src = reader.get_tensor(key) if rank == 0 else None + n_exp = ( + int(src.shape[0]) + if src is not None + else _infer_num_experts(base, name_to_param) + ) + for i in range(n_exp): + gkey = f"{base}.{i}.gate_proj.weight" + ukey = f"{base}.{i}.up_proj.weight" + gtgt = name_to_param.get(gkey) + utgt = name_to_param.get(ukey) + if gtgt is None or utgt is None: + skipped += 2 + continue + gsrc = src[i].chunk(2, dim=-2)[0] if src is not None else None + usrc = src[i].chunk(2, dim=-2)[1] if src is not None else None + if _broadcast_into_target(gsrc, gtgt, key=gkey): + seen_targets.add(gkey) + loaded += 1 + else: + skipped += 1 + if _broadcast_into_target(usrc, utgt, key=ukey): + seen_targets.add(ukey) + loaded += 1 + else: + skipped += 1 + del src + elif key.endswith(".experts.down_proj"): + base = key[: -len(".down_proj")] + src = reader.get_tensor(key) if rank == 0 else None + n_exp = ( + int(src.shape[0]) + if src is not None + else _infer_num_experts(base, name_to_param) + ) + for i in range(n_exp): + dkey = f"{base}.{i}.down_proj.weight" + dtgt = name_to_param.get(dkey) + if dtgt is None: + skipped += 1 + continue + dsrc = src[i] if src is not None else None + if _broadcast_into_target(dsrc, dtgt, key=dkey): + seen_targets.add(dkey) + loaded += 1 + else: + skipped += 1 + del src + else: + tgt = name_to_param.get(key) + is_buf = False + if tgt is None: + tgt = name_to_buffer.get(key) + is_buf = tgt is not None + if tgt is None: + skipped += 1 + continue + src = reader.get_tensor(key) if rank == 0 else None + if _broadcast_into_target(src, tgt, is_buffer=is_buf, key=key): + seen_targets.add(key) + loaded += 1 + else: + skipped += 1 + del src + _cleanup() + print_info(f"{log_prefix} loaded shard {os.path.basename(shard_path)}") + + all_targets 
= set(name_to_param) | set(name_to_buffer) + missing = sorted(all_targets - seen_targets) + print_info( + f"{log_prefix} stream_load_weights done: " + f"loaded={loaded} skipped={skipped} missing={len(missing)}" + ) + if missing: + print_info(f"{log_prefix} first missing keys: {missing[:10]}") + + try: + model.tie_weights() + except Exception as e: # noqa: BLE001 + print_info(f"{log_prefix} tie_weights skipped: {e}") + + +def _infer_num_experts(base, name_to_param): + prefix = f"{base}." + ids = [] + for name in name_to_param: + if not name.startswith(prefix): + continue + first = name[len(prefix) :].split(".", 1)[0] + if first.isdigit(): + ids.append(int(first)) + return (max(ids) + 1) if ids else 0 + + +# --------------------------------------------------------------------------- +# Streaming PTQ-scale loader for QAT warm-start +# --------------------------------------------------------------------------- + + +_SCALE_SUFFIX_RULES = [ + # (suffix_in_ckpt, quantizer_attr_on_QuantLinear, sub_attr, layer_name_rewrite) + (".weight_zero_point", "weight_quantizer", "zero_point", None), + (".input_zero_point", "act_quantizer", "zero_point", None), + (".weight_scale", "weight_quantizer", "scale", None), + (".input_scale", "act_quantizer", "scale", None), + (".k_cache.scale", "qkv_quantizer", "scale", ".k_proj"), + (".v_cache.scale", "qkv_quantizer", "scale", ".v_proj"), +] +# Longest first to avoid '.scale' winning over '.k_cache.scale'. +_SCALE_SUFFIX_RULES.sort(key=lambda r: len(r[0]), reverse=True) + + +def _parse_scale_key(key): + for suffix, qname, sub, rewrite in _SCALE_SUFFIX_RULES: + if key.endswith(suffix): + base = key[: -len(suffix)] + return (base + rewrite if rewrite else base), qname, sub + return None + + +def _expand_scale_targets(layer_name, qname, sub, named_modules): + """Expand a checkpoint key that targets a fused MoE expert tensor into the + matching per-expert linears in the linearized model. For non-MoE keys + returns ``[(layer_name, qname, sub)]`` unchanged.""" + if layer_name in named_modules: + return [(layer_name, qname, sub)] + + # PTQ checkpoint may store scales for the fused expert matrix; map them + # to every per-expert Linear we expanded into. 
+ if layer_name.endswith(".experts.gate_up_proj"): + base = layer_name[: -len(".gate_up_proj")] + return [ + (n, qname, sub) + for n in named_modules + if n.startswith(base + ".") and (n.endswith(".gate_proj") or n.endswith(".up_proj")) + ] + if layer_name.endswith(".experts.down_proj"): + base = layer_name[: -len(".down_proj")] + return [ + (n, qname, sub) + for n in named_modules + if n.startswith(base + ".") and n.endswith(".down_proj") + ] + return [] + + +def _copy_scale_into(src, target): + """rank0-driven copy of a scale-like tensor into a (possibly ZeRO-3) + Parameter, with shape coercion for scalar/per-tensor mismatch.""" + rank = _rank() + ok = True + with gathered_param_if_zero3(target, modifier_rank=0): + if rank == 0: + s = src + if s.numel() == target.numel(): + s = s.reshape(target.shape) + else: + try: + s = s.expand_as(target).contiguous() + except RuntimeError: + ok = False + if ok: + target.data.copy_(s.to(device=target.device, dtype=target.dtype)) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + device = ( + torch.device("cuda", torch.cuda.current_device()) + if torch.cuda.is_available() + else target.device + ) + flag = torch.tensor(int(ok), device=device) + torch.distributed.broadcast(flag, src=0) + ok = bool(flag.item()) + return ok + + +def stream_load_scales(model, ckpt_dir, log_prefix="[zero3]"): + """Read a PTQ "real" checkpoint and write its scale / zero_point / + kv-cache scale tensors into the matching ``QuantLinear`` quantizer + parameters of ``model``. + + Sets ``act_quantizer.init = True`` for every static activation + quantizer that successfully receives a scale, so the lazy-init pass is + skipped. + """ + from safetensors import safe_open + + # Resolve nested layout (some PTQ exporters nest under final_quant_checkpoint/). + if not glob.glob(os.path.join(ckpt_dir, "*.safetensors")): + nested = os.path.join(ckpt_dir, "final_quant_checkpoint") + if os.path.isdir(nested): + ckpt_dir = nested + + files = sorted(glob.glob(os.path.join(ckpt_dir, "*.safetensors"))) + if not files: + raise FileNotFoundError(f"No *.safetensors in {ckpt_dir}") + + # Lazy import to avoid circular dependency: this module is imported by + # angelslim.utils, which is imported very early. 
+ from ..compressor.qat.modules.quantizer import QuantLinear + + named_modules = dict(model.named_modules()) + rank = _rank() + loaded = 0 + skipped = 0 + + for src_file in files: + with safe_open(src_file, framework="pt") as reader: + for key in reader.keys(): + parsed = _parse_scale_key(key) + if parsed is None: + continue + layer_name, qname, sub = parsed + targets = _expand_scale_targets(layer_name, qname, sub, named_modules) + if not targets: + skipped += 1 + continue + src = reader.get_tensor(key) if rank == 0 else None + for tgt_layer, tgt_qname, tgt_sub in targets: + module = named_modules.get(tgt_layer) + if not isinstance(module, QuantLinear): + skipped += 1 + continue + quantizer = getattr(module, tgt_qname, None) + if quantizer is None: + skipped += 1 + continue + target = getattr(quantizer, tgt_sub, None) + if not isinstance(target, torch.nn.Parameter): + skipped += 1 + continue + if _copy_scale_into(src, target): + loaded += 1 + if tgt_qname == "act_quantizer" and tgt_sub == "scale": + quantizer.init = True + else: + skipped += 1 + + print_info( + f"{log_prefix} stream_load_scales: loaded={loaded} skipped={skipped} from {ckpt_dir}" + ) + + +# --------------------------------------------------------------------------- +# Saving a sharded model via the model-specific save_func +# --------------------------------------------------------------------------- + + +def consolidated_state_dict(model): + """rank-0 CPU state_dict for a possibly ZeRO-3 sharded ``model``. + + Other ranks see an empty dict (matching the contract of HF/Trainer + save callbacks). Includes persistent buffers.""" + rank = _rank() + sd = {} + for name, param in model.named_parameters(): + with gathered_param_if_zero3(param): + if rank == 0: + sd[name] = param.detach().cpu().clone() + if rank == 0: + for module_name, module in model.named_modules(): + for buf_name, buf in module.named_buffers(recurse=False): + if buf is None or buf_name in module._non_persistent_buffers_set: + continue + full = f"{module_name}.{buf_name}" if module_name else buf_name + sd[full] = buf.detach().cpu().clone() + return sd + + +def save_via_model_save_func( + quant_model, + save_func, + save_target_dir, + prebuilt_state_dict=None, +): + """Invoke ``save_func.save(...)`` with the model's ``state_dict`` patched + to return a consolidated rank-0 dict. + + Parameters + ---------- + prebuilt_state_dict : dict | None + If provided (rank 0), use this dict directly and skip the per-param + gather+clone pass over ``model``. This is the **recommended** path + under ZeRO-3 because the caller (typically ``QAT.convert``) has + already produced rank0's full state_dict layer-by-layer with bounded + peak memory. Other ranks may pass ``None``. + + If ``None``, fall back to ``consolidated_state_dict(model)``, which + gathers every parameter once more — avoid this when ``model`` + already holds large materialised tensors on rank 0 (it would double + the peak). + + No-op (delegates straight to ``save_func.save``) when no parameters are + sharded. 
+ """ + if not model_has_zero3_params(quant_model.model) and prebuilt_state_dict is None: + save_func.save(save_target_dir) + return + + rank = _rank() + if prebuilt_state_dict is not None: + sd = prebuilt_state_dict if rank == 0 else {} + else: + sd = consolidated_state_dict(quant_model.model) + + hf_model = quant_model.get_model() + original = hf_model.state_dict + + def _patched(*args, **kwargs): + return sd if rank == 0 else {} + + try: + hf_model.state_dict = _patched # type: ignore[method-assign] + if rank == 0: + save_func.save(save_target_dir) + finally: + hf_model.state_dict = original # type: ignore[method-assign] + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + +# --------------------------------------------------------------------------- +# DeepSpeed engine: tolerate scale parameters that appear in multiple param +# groups (the scale tensor itself is unique, but our optimizer construction +# may pick it up from multiple QuantLinear children of a shared MoE expert). +# --------------------------------------------------------------------------- + + +def patch_deepspeed_duplicate_check(): + """No-op DeepSpeed's ``_check_for_duplicates`` so QAT scale parameters + that share storage across modules don't crash optimizer init. + + Idempotent: only patches once per process. + """ + try: + from deepspeed.runtime.engine import DeepSpeedEngine + except Exception as exc: # noqa: BLE001 + print_info(f"[zero3] skip duplicate-check patch: {exc}") + return + if getattr(DeepSpeedEngine, "_angelslim_skip_dup_check", False): + return + + def _noop(self, basic_optimizer): # noqa: ARG001 + return + + DeepSpeedEngine._check_for_duplicates = _noop + DeepSpeedEngine._angelslim_skip_dup_check = True + print_info("[zero3] patched DeepSpeed _check_for_duplicates (idempotent).") diff --git a/configs/hunyuan/ptq/fp8_static/hunyuanv3_a20b_fp8_static.yaml b/configs/hunyuan/ptq/fp8_static/hunyuanv3_a20b_fp8_static.yaml new file mode 100644 index 00000000..06558ffd --- /dev/null +++ b/configs/hunyuan/ptq/fp8_static/hunyuanv3_a20b_fp8_static.yaml @@ -0,0 +1,38 @@ +# Global configuration of pipeline +global: + save_path: ./output_ptq_hy3 + +# Simplified Configuration for LLM compression +model: + name: HYV3MoE + model_path: tencent/Hy3-preview + trust_remote_code: true + low_cpu_mem_usage: true + use_cache: false + torch_dtype: auto + device_map: auto + +# Compression configuration +compression: + name: PTQ + quantization: + name: fp8_static + bits: 8 + quant_method: + weight: "per-tensor" + activation: "per-tensor" + kv_cache: "per-tensor" + ignore_layers: + - "lm_head" + - "model.embed_tokens" + - "gate.weight" + cpu_convert: true + save_name: "fp8" + +# Dataset for calibration +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4/sharegpt_gpt4_256.jsonl + max_seq_length: 2048 + num_samples: 512 + batch_size: 1 diff --git a/configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json b/configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json new file mode 100644 index 00000000..e513dd7b --- /dev/null +++ b/configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json @@ -0,0 +1,45 @@ +{ + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "total_num_steps": "auto", + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + "zero_optimization": { 
+ "stage": 3, + "offload_optimizer": { + "device": "none" + }, + "offload_param": { + "device": "none" + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 10, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml new file mode 100644 index 00000000..d110e90d --- /dev/null +++ b/configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml @@ -0,0 +1,72 @@ +global: + save_path: ./output_hy3_zero3 + +model: + name: HYV3MoE + model_path: tencent/Hy3-preview + trust_remote_code: true + torch_dtype: auto + device_map: None + low_cpu_mem_usage: true + use_cache: false + +compression: + name: QAT + quantization: + name: fp8_static + bits: 8 + quant_method: + weight: per-tensor + activation: per-tensor + ignore_layers: ["lm_head", "embed_tokens", "gate.weight"] + QAT: + hf_dataset: null + from_ptq_ckpt: ./output_ptq_hy3/hunyuanv3_a20b_fp8_static + training_mode: end2end + dist_mode: hf + save_format: real + do_train: true + loss_type: cakld + loss_topk: null + kd_temperature: 1.0 + kd_alpha: 0.5 + lm_loss_weight: 1.0 + kd_loss_weight: 1.0 + plugin_config: + enable_scale: true + quant_config: + use_weight_quant: true + use_activation_quant: true + use_qkv_quant: false + weight_scale_init_value: 0.1 + activation_scale_init_value: 0.1 + learnable: + act_scale: false + weight_scale: true + kv_scale: false + norm: false + lwc: false + hf_args: + bf16: true + logging_steps: 1 + logging_first_step: true + per_device_train_batch_size: 1 + gradient_accumulation_steps: 1 + learning_rate: 1.0e-6 + lr_scheduler_type: cosine + num_train_epochs: 1 + max_steps: 3 + save_strategy: "no" + warmup_ratio: 0.0 + max_grad_norm: 1.0 + gradient_checkpointing: true + gradient_checkpointing_kwargs: + use_reentrant: true + deepspeed: configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json + +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4/sharegpt_gpt4_256.jsonl + max_seq_length: 2048 + num_samples: 256 + batch_size: 1 diff --git a/configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json b/configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json new file mode 100644 index 00000000..e513dd7b --- /dev/null +++ b/configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json @@ -0,0 +1,45 @@ +{ + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "total_num_steps": "auto", + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "none" + }, + "offload_param": { + "device": "none" + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + 
"stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 10, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml new file mode 100644 index 00000000..f2940943 --- /dev/null +++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml @@ -0,0 +1,86 @@ +global: + save_path: ./output + +model: + name: Qwen + # Local directory or HuggingFace repo id of the base Qwen3-30B-A3B model. + model_path: Qwen/Qwen3-30B-A3B + trust_remote_code: true + torch_dtype: auto + # ZeRO-3 wants HF to NOT place the model with ``device_map``. Pass + # ``None`` here; the Engine normalises the string "None"/"distributed". + device_map: None + low_cpu_mem_usage: true + use_cache: false + +compression: + name: QAT + quantization: + name: fp8_static + bits: 8 + quant_method: + weight: per-tensor + activation: per-tensor + ignore_layers: ["lm_head", "embed_tokens", "gate.weight"] + QAT: + # Leave ``hf_dataset`` unset → the trainer falls back to the local + # TextDataset defined in the ``dataset:`` section below. + hf_dataset: null + # REQUIRED under ZeRO-3: warm-start scales from a previous PTQ "real" + # checkpoint. Produce it with e.g. + # python tools/run.py -c configs/qwen3/ptq/fp8_static/qwen3-a3b_fp8_static.yaml + from_ptq_ckpt: ./output_ptq_30b/qwen3-a3b_fp8_static + training_mode: end2end + dist_mode: hf + save_format: real + do_train: true + # KD + LM loss combination: loss_type picks the KD variant, + # lm_loss_weight / kd_loss_weight control the outer mix. + # loss_type choices: kl | rkl | mse | cakld | kd | kl_top_K | r_kl_top_K + loss_type: cakld + loss_topk: null + kd_temperature: 1.0 + kd_alpha: 0.5 + lm_loss_weight: 1.0 + kd_loss_weight: 1.0 + plugin_config: + enable_scale: true + quant_config: + use_weight_quant: true + use_activation_quant: true + use_qkv_quant: false + # Init values used when weight data is not accessible (e.g. under + # ZeRO-3) and for quantizer parameters missing from the PTQ + # checkpoint. 
+        weight_scale_init_value: 0.1
+        activation_scale_init_value: 0.1
+      learnable:
+        act_scale: false
+        weight_scale: true
+        kv_scale: false
+        norm: false
+        lwc: false
+    hf_args:
+      bf16: true
+      logging_steps: 1
+      logging_first_step: true
+      per_device_train_batch_size: 1
+      gradient_accumulation_steps: 1
+      learning_rate: 1.0e-6
+      lr_scheduler_type: cosine
+      num_train_epochs: 1
+      max_steps: 3
+      save_strategy: "no"
+      warmup_ratio: 0.0
+      max_grad_norm: 1.0
+      gradient_checkpointing: true
+      gradient_checkpointing_kwargs:
+        use_reentrant: true
+      deepspeed: configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json
+
+dataset:
+  name: TextDataset
+  data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl
+  max_seq_length: 4096
+  num_samples: 256
+  batch_size: 1
diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml
new file mode 100644
index 00000000..bb0822b2
--- /dev/null
+++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml
@@ -0,0 +1,73 @@
+global:
+  save_path: ./output_4b_zero3
+
+model:
+  name: Qwen
+  model_path: Qwen/Qwen3-4B
+  trust_remote_code: true
+  torch_dtype: auto
+  device_map: None
+  low_cpu_mem_usage: true
+  use_cache: false
+
+compression:
+  name: QAT
+  quantization:
+    name: fp8_static
+    bits: 8
+    quant_method:
+      weight: per-tensor
+      activation: per-tensor
+    ignore_layers: ["lm_head", "embed_tokens"]
+  QAT:
+    hf_dataset: null
+    from_ptq_ckpt: ./output_ptq/qwen3-4b_fp8_static
+    training_mode: end2end
+    dist_mode: hf
+    save_format: real
+    do_train: true
+    loss_type: cakld
+    loss_topk: null
+    kd_temperature: 1.0
+    kd_alpha: 0.5
+    lm_loss_weight: 1.0
+    kd_loss_weight: 1.0
+    plugin_config:
+      enable_scale: true
+      quant_config:
+        use_weight_quant: true
+        use_activation_quant: true
+        use_qkv_quant: false
+        weight_scale_init_value: 0.1
+        activation_scale_init_value: 0.1
+      learnable:
+        act_scale: false
+        weight_scale: true
+        kv_scale: false
+        norm: false
+        lwc: false
+    hf_args:
+      bf16: true
+      logging_steps: 1
+      logging_first_step: true
+      per_device_train_batch_size: 1
+      gradient_accumulation_steps: 1
+      learning_rate: 1.0e-6
+      lr_scheduler_type: cosine
+      num_train_epochs: 1
+      max_steps: 5
+      save_strategy: "no"
+      warmup_ratio: 0.0
+      max_grad_norm: 1.0
+      gradient_checkpointing: true
+      gradient_checkpointing_kwargs:
+        use_reentrant: true
+      deepspeed: configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json
+
+dataset:
+  name: TextDataset
+  data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl
+  max_seq_length: 512
+  num_samples: 16
+  batch_size: 1
+  is_sft_data: true
diff --git a/docs/source/features/quantization/index.md b/docs/source/features/quantization/index.md
index 2048f22d..3878481c 100644
--- a/docs/source/features/quantization/index.md
+++ b/docs/source/features/quantization/index.md
@@ -13,5 +13,6 @@ awq
 gptq
 fp8_lepto
 qat
+qat_zero3
 daq
 :::
diff --git a/docs/source/features/quantization/qat_zero3.md b/docs/source/features/quantization/qat_zero3.md
new file mode 100644
index 00000000..49c8e2ff
--- /dev/null
+++ b/docs/source/features/quantization/qat_zero3.md
@@ -0,0 +1,323 @@
+# QAT + DeepSpeed ZeRO-3
+
+## Overview
+
+This document describes **DeepSpeed ZeRO-3** support in the AngelSlim QAT module. The motivation: once the base model outgrows a single GPU (for example an MoE model such as Qwen3-30B-A3B), parameters, gradients, and optimizer states must be sharded across GPUs. HuggingFace + DeepSpeed ZeRO-3 is the established solution, but plugging it directly into the QAT pipeline runs into the following problems:
+
+1. Under ZeRO-3, HuggingFace `from_pretrained` makes every rank load the **full** state_dict on CPU before partitioning; peak memory is roughly `world_size × model_size`, a guaranteed OOM for large models (see the sketch after this list for the registration mechanism used instead).
+2. QAT replaces the original `nn.Linear` with `QuantLinear` and splits fused MoE experts into per-expert `nn.Linear` modules. Performing these replacements on ZeRO-3 sharded parameters triggers shard / gather / repartition traffic all at once; the logic is error-prone and the memory peak uncontrollable.
+3. QAT's activation scales are by default calibrated lazily through forward passes, which cannot run under ZeRO-3; weight-scale initialization requires reading the full weights, which is equally unavailable.
+4. Both `convert()` (turning `QuantLinear` into `QDQModule`) and the `save_format=real` export of compressed weights need parameters consolidated across GPUs while keeping only **one CPU copy, on rank0**.
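+
+Problem 1 is what the `_prewarm_hf_deepspeed_config` hook in `tools/run.py` (see the diff at the end of this change) works around: the DeepSpeed config must be registered with HuggingFace *before* any model is built. A minimal sketch of the underlying mechanism, assuming the stock transformers/deepspeed integration and execution under `torchrun`:
+
+```python
+from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.integrations.deepspeed import HfDeepSpeedConfig
+
+# Registering the config flips is_deepspeed_zero3_enabled() to True.
+# transformers keeps only a weak reference, so the object must stay alive.
+dschf = HfDeepSpeedConfig("configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json")
+
+config = AutoConfig.from_pretrained("Qwen/Qwen3-4B")
+# from_config now runs under deepspeed.zero.Init: every Parameter is
+# partitioned at construction time, so no rank ever holds the full model.
+model = AutoModelForCausalLM.from_config(config)
+```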
+
+To address these problems, the key design decisions are:
+
+- **Every rank builds an empty model independently** (no weights read from disk), partitioned immediately via `deepspeed.zero.Init`; peak memory is only `model_size / world_size`.
+- **Fused MoE splitting happens on the empty model** (duck-typing on `gate_up_proj / down_proj / num_experts / hidden_dim / intermediate_dim / act_fn`); the newly built per-expert `nn.Linear` modules are sharded by `deepspeed.zero.Init` on creation, so **no data is copied from the old fused tensors**.
+- **Both weights and scales are streamed from safetensors**: rank0 reads file by file, the other ranks receive data through `GatheredParameters(modifier_rank=0)`. Fused MoE keys `.experts.gate_up_proj` / `.experts.down_proj` are sliced automatically into each per-expert target.
+- **Under ZeRO-3, scales must come from an external PTQ checkpoint** (`from_ptq_ckpt`), skipping forward calibration; scales with no match there fall back to `weight_scale_init_value` / `activation_scale_init_value`.
+- **`convert / save` builds `QDQModule`s on rank0 only and writes straight into a consolidated state_dict**; the other ranks merely participate in the NCCL gather collectives and never hold CPU data.
+
+With this layout, the non-ZeRO-3 path behaves exactly as on main: all ZeRO-3 logic lives in a single new file, `angelslim/utils/zero3_io.py`, and the remaining files only make a few thin calls into it.
+
+## Architecture
+
+### Module layout
+
+```
+angelslim/
+├── utils/
+│   └── zero3_io.py                  # ALL ZeRO-3 helpers: detection,
+│                                    # gather/scatter, empty-model build,
+│                                    # streaming weight/scale loaders,
+│                                    # consolidated save, optimizer patch.
+├── compressor/qat/
+│   ├── qat.py                       # ZeRO-3 branch for convert() / save()
+│   ├── modules/quantizer.py         # init scales from init_value when
+│   │                                # weight data is not accessible; dtype
+│   │                                # alignment for DeepSpeed autocast.
+│   ├── plugins/learnable_scale.py   # stream_load_scales + skip lazy_init
+│   │                                # + gathered quant_inplace
+│   └── trainers/end2end_trainer.py  # lm_loss + kd_loss composition with
+│                                    # cakld support + per-component logging
+├── data/text_dataset.py             # supervise ONLY the last assistant turn
+├── models/base_model.py             # from_pretrained → ZeRO-3 path
+├── engine.py                        # normalise ``device_map`` string
+└── utils/
+    ├── config_parser.py             # QATTrainingConfig new fields
+    ├── utils.py                     # set_op_by_name handles string-indexed children
+    └── __init__.py                  # re-export zero3_io helpers
+```
+
+### Execution flow (ZeRO-3 path)
+
+```
+tools/run.py
+  └── _prewarm_hf_deepspeed_config()            # register HfTrainerDeepSpeedConfig
+  └── Engine.prepare_model()
+        └── BaseLLMModel.from_pretrained()      # ZeRO-3 branch
+              ├── zero3_empty_model_from_pretrained()
+              │     ├── AutoModelForCausalLM.from_config(...)
+              │     │     # triggers deepspeed.zero.Init for every Parameter
+              │     └── linearize_moe_experts_empty()
+              │           # fused Qwen3MoeExperts → empty LinearizedMoeExperts
+              └── stream_load_weights()         # rank0 reads safetensors
+  └── QAT.__init__()
+        ├── init_ptq() → Qwen.replace_moe()     # no-op: already linearised
+        └── register LearnableScalePlugin(from_ptq_ckpt_dir=...)
+  └── LearnableScalePlugin.before_train()
+        ├── replace nn.Linear with QuantLinear
+        │     └── Quantizer allocates scale Parameters using init_value
+        │           (no dependency on weight data)
+        └── stream_load_scales(from_ptq_ckpt)   # fill weight/act/kv scales
+  └── End2EndTrainer.prepare_trainer()
+        ├── _init_optimizer() with id-deduped scale/LWC params
+        ├── patch_deepspeed_duplicate_check()   # scales may be tied
+        └── HF Trainer.train()
+              # student + teacher forward → lm_loss + kd_loss composition
+  └── QAT.convert() + QAT.save()
+        # rank0-only QDQModule + consolidated state_dict → single rank write
+```
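+
+The same building blocks can also be driven by hand. A minimal sketch, assuming a DeepSpeed ZeRO-3 config has already been registered (as in the sketch above) and paths are placeholders:
+
+```python
+from angelslim.utils.zero3_io import (
+    stream_load_scales,
+    stream_load_weights,
+    zero3_empty_model_from_pretrained,
+)
+
+model_path = "/path/to/Qwen3-30B-A3B"                # placeholder
+ptq_ckpt = "./output_ptq_30b/qwen3-a3b_fp8_static"   # PTQ "real" checkpoint
+
+# 1) Empty, already-partitioned model; fused MoE experts are linearized.
+model = zero3_empty_model_from_pretrained(model_path, torch_dtype="auto")
+# 2) rank0 streams the safetensors shards into the sharded parameters.
+stream_load_weights(model, model_path)
+# 3) Only AFTER QuantLinear insertion: warm-start scales from the PTQ ckpt.
+stream_load_scales(model, ptq_ckpt)
+```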
+
+## Usage
+
+### Prerequisites
+
+1. Dependencies: `deepspeed`, `safetensors`, `compressed-tensors` (optional, used to read the exported fp8 checkpoint).
+2. Hardware: a multi-GPU node with NCCL; the ZeRO-3 path requires `torchrun --nproc_per_node=N`.
+
+### Full two-stage flow
+
+#### Stage 1: PTQ calibration to produce initial scales
+
+ZeRO-3 QAT **no longer** runs forward calibration at startup, so a PTQ checkpoint containing scales must be produced first. Use the existing PTQ config (a single GPU is enough):
+
+```bash
+python tools/run.py \
+    -c configs/qwen3/ptq/fp8_static/qwen3-a3b_fp8_static.yaml \
+    --model-path /path/to/Qwen3-30B-A3B \
+    --save-path ./output_ptq_30b
+```
+
+The resulting `./output_ptq_30b/qwen3-a3b_fp8_static/model-*.safetensors` contain `.weight_scale` and `.input_scale` tensors.
+
+#### Stage 2: ZeRO-3 QAT
+
+```bash
+bash scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh
+```
+
+or directly:
+
+```bash
+torchrun --nproc_per_node=8 tools/run.py \
+    -c configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml
+```
+
+After training, the compressed weights are saved to `./output/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3/final_quant_checkpoint/`.
+
+### Minimal 4B smoke test
+
+Quick end-to-end verification (2 GPUs, 5 training steps):
+
+```bash
+# Step 1: PTQ
+python tools/run.py \
+    -c configs/qwen3/ptq/fp8_static/qwen3-4b_fp8_static.yaml \
+    --model-path Qwen/Qwen3-4B \
+    --save-path ./output_ptq
+
+# Step 2: QAT
+bash scripts/qat/run_qat_for_qwen_4b_zero3.sh
+```
+
+## Configuration
+
+### New / changed fields for ZeRO-3 QAT (`compression.QAT`)
+
+| Field | Type | Default | Description |
+|------|------|--------|------|
+| `from_ptq_ckpt` | str / null | `null` | Output directory of a PTQ `save_format="real"` run. **Required** under ZeRO-3, otherwise `before_train` raises. May point at the top-level directory (a nested `final_quant_checkpoint/` is detected automatically). |
+| `lm_loss_weight` | float | `1.0` | Weight of the HF CausalLM CE loss. At 0 the lm loss is neither computed nor logged. |
+| `kd_loss_weight` | float | `0.0` | Weight of the KD loss (the variant selected by `loss_type`). At 0 no KD is computed and no teacher forward is run. |
+
+Note: training proceeds as long as either `lm_loss_weight` or `kd_loss_weight` is > 0; if both are 0, `compute_loss` raises.
+
+### KD variants (`loss_type`)
+
+| `loss_type` | Description |
+|-------------|------|
+| `kl` | `KL(teacher || student)` (forward KL), averaged per valid token |
+| `rkl` | `KL(student || teacher)` (reverse / backward KL), averaged per valid token |
+| `mse` | MSE between student and teacher logits (averaged per valid token) |
+| `cakld` | Confidence-Aware KL Distillation: a token-wise `fkl` / `rkl` mix weighted by the teacher's probability of the label, `conf * rkl + (1 - conf) * fkl` |
+| `kd` | Classic temperature KD: `T² * KL(soft_student || soft_teacher)`; kept for backwards compatibility |
+| `kl_top_K` / `r_kl_top_K` | Forward / reverse KL over the top-K tokens. `K` can be embedded in the `loss_type` string (e.g. `kl_top_1000`) or given via the `loss_topk` field |
+| `origin` | Plain HF CE loss (equivalent to `kd_loss_weight = 0`) |
+
+The core formula for `loss_type = cakld` (see `_compute_kd_components`):
+
+```python
+# computed only on tokens where labels != -100
+forward_kl = KL(log_softmax(student), softmax(teacher))
+backward_kl = KL(log_softmax(teacher), softmax(student))
+conf = softmax(teacher).gather(-1, label)  # teacher confidence on the target token
+cakld = (conf * backward_kl + (1 - conf) * forward_kl).mean()
+```
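+
+How the two terms combine into the trained objective, as a self-contained sketch (function and variable names are illustrative, not the trainer's exact internals; logits and labels are assumed already aligned):
+
+```python
+import torch.nn.functional as F
+
+def combined_loss(lm_loss, student_logits, teacher_logits, labels,
+                  lm_w=1.0, kd_w=1.0):
+    """total_loss = lm_w * lm_loss + kd_w * kd_loss, with kd = cakld."""
+    mask = labels != -100                      # supervised positions only
+    log_ps = F.log_softmax(student_logits[mask], dim=-1)
+    log_pt = F.log_softmax(teacher_logits[mask], dim=-1)
+    # per-token KL(p_t || p_s) and KL(p_s || p_t)
+    fkl = F.kl_div(log_ps, log_pt, log_target=True, reduction="none").sum(-1)
+    rkl = F.kl_div(log_pt, log_ps, log_target=True, reduction="none").sum(-1)
+    # teacher confidence on the target token drives the mix
+    conf = log_pt.exp().gather(-1, labels[mask].unsqueeze(-1)).squeeze(-1)
+    kd = (conf * rkl + (1 - conf) * fkl).mean()
+    return lm_w * lm_loss + kd_w * kd
+```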
+
+### Training logs
+
+`QATSeq2SeqTrainer.log()` injects the metrics below into the HF Trainer's standard log dict (so they are visible in wandb, the console, and tqdm alike); each appears only when its corresponding weight is > 0:
+
+| Metric | Meaning |
+|------|------|
+| `lm_loss` | HF CausalLM CE loss (computed only on assistant-reply positions, see below) |
+| `kd/<loss_type>` | The selected primary KD loss (`cakld` / `kl` / ...) |
+| `kd/forward_kl` | Diagnostic: `KL(teacher || student)`, always printed when `kd_loss_weight > 0` |
+| `kd/backward_kl` | Diagnostic: `KL(student || teacher)` |
+| `total_loss` | `lm_loss_weight * lm_loss + kd_loss_weight * kd/<loss_type>` |
+
+Example output (30B-A3B, 3 training steps):
+
+```
+{'loss': 2.08, 'grad_norm': 43.1, 'learning_rate': 1e-6,
+ 'lm_loss': 1.23, 'kd/cakld': 0.0075, 'kd/forward_kl': 0.0075, 'kd/backward_kl': 0.0075,
+ 'total_loss': 1.24, 'epoch': 0.03}
+```
+
+### Dataset: supervise only the last assistant reply
+
+For JSONL SFT data (any of the `messages` / `conversations` / `input+output` schemas), `TextDataset._load_jsonl_data` now computes the loss only on the position of the **last assistant reply**:
+
+- Concatenate the prompt (every turn before the last assistant reply) and tokenize it via `apply_chat_template(..., add_generation_prompt=True)` to obtain `prompt_len`.
+- Tokenize the full conversation (including the last assistant reply) to get `input_ids`, with `labels = input_ids.clone()`.
+- Set `labels[:, :prompt_len]` and all padding positions to `-100`; the HF CausalLM loss ignores them automatically.
+
+This matches HF CausalLM's internal shift (`shift_logits[..., :-1]` aligned with `shift_labels[..., 1:]`), so **no** manual `roll` is needed.
+
+### DeepSpeed configuration
+
+See `configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json`. Key entries:
+
+```json
+{
+  "bf16": {"enabled": "auto"},
+  "zero_optimization": {
+    "stage": 3,
+    "stage3_gather_16bit_weights_on_model_save": true,
+    "overlap_comm": true,
+    "contiguous_gradients": true
+  },
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto"
+}
+```
+
+`hf_args` must set `bf16: true` (or `fp16: true`) to enable mixed precision explicitly; otherwise DeepSpeed's `zero3_linear_wrap` defaults to fp16 autocast while the model weights are bf16, causing a dtype mismatch.
+
+### Quantizer initial values
+
+When ZeRO-3 is enabled (or the weight is ZeRO-3 sharded / on meta / zero-numel), `Quantizer._init_quant_params` no longer depends on the weight values and creates the `nn.Parameter`s purely from shape:
+
+| Option | Default | Description |
+|------|--------|------|
+| `weight_scale_init_value` | `1.0` | Initial value of the weight-quantizer scale; survives only where `from_ptq_ckpt` has no matching entry |
+| `activation_scale_init_value` | `1.0` | Initial value of the activation-quantizer scale |
+
+In typical settings the weight scales produced by PTQ are around `max(|W|) / 448 ≈ 1e-3`, so setting the init values to about `0.1` in the yaml is a sensible safety net (they are overwritten whenever `from_ptq_ckpt` hits).
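+
+For intuition, a quick toy check of that `max(|W|) / 448` magnitude (illustrative numbers, not from a real model):
+
+```python
+import torch
+
+w = torch.randn(4096, 4096) * 0.02           # toy weight tensor
+scale = (w.abs().max() / 448.0).item()       # per-tensor fp8 e4m3 scale
+print(f"{scale:.1e}")                        # ~2e-4, far below the 1.0 default
+```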
+
+### Key yaml fragment
+
+```yaml
+compression:
+  name: QAT
+  quantization:
+    name: fp8_static
+    quant_method:
+      weight: per-tensor
+      activation: per-tensor
+    ignore_layers: ["lm_head", "embed_tokens", "gate.weight"]
+  QAT:
+    hf_dataset: null
+    from_ptq_ckpt: ./output_ptq_30b/qwen3-a3b_fp8_static
+    training_mode: end2end
+    dist_mode: hf
+    save_format: real
+    loss_type: cakld
+    lm_loss_weight: 1.0
+    kd_loss_weight: 1.0
+    plugin_config:
+      enable_scale: true
+      quant_config:
+        use_weight_quant: true
+        use_activation_quant: true
+        weight_scale_init_value: 0.1
+        activation_scale_init_value: 0.1
+      learnable:
+        act_scale: false
+        weight_scale: true
+        kv_scale: false
+        norm: false
+        lwc: false
+    hf_args:
+      bf16: true
+      per_device_train_batch_size: 1
+      learning_rate: 1.0e-6
+      gradient_checkpointing: true
+      deepspeed: configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json
+```
+
+Full examples:
+- `configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml`
+- `configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml`
+
+## Implementation notes
+
+### Main APIs in `zero3_io.py`
+
+| Function / class | Purpose |
+|-----------|------|
+| `is_deepspeed_zero3_enabled()` | Checks via HF's `HfTrainerDeepSpeedConfig` weak reference whether a ZeRO-3 config is registered |
+| `is_zero3_param(p)` | Checks whether a Parameter carries `ds_id / ds_status / ds_numel / ds_tensor` metadata |
+| `gathered_param_if_zero3(p, modifier_rank=None)` | Context manager; a no-op for non-ZeRO-3 parameters |
+| `LinearizedMoeExperts` | Generic empty per-expert `nn.Linear` container whose `forward` is equivalent to the HF fused one |
+| `linearize_moe_experts_empty(model)` | Duck-typed scan and in-place replacement; constructs inside `deepspeed.zero.Init` so the new modules are partitioned directly |
+| `zero3_empty_model_from_pretrained(model_path, ...)` | Builds an empty model via `no_init_weights` + `from_config` and linearizes MoE automatically |
+| `stream_load_weights(model, model_path)` | Streaming weight fill, with per-expert slicing of fused MoE keys |
+| `stream_load_scales(model, ckpt_dir)` | Streaming scale fill for `.weight_scale / .input_scale / .k_cache.scale / .v_cache.scale` |
+| `save_via_model_save_func(quant_model, save_func, path, prebuilt_state_dict)` | Calls the original save_func on rank0 only, with `state_dict()` patched to return the consolidated dict |
+| `patch_deepspeed_duplicate_check()` | No-ops `DeepSpeedEngine._check_for_duplicates` so tied scale parameters are allowed |
+
+### `Quantizer` changes
+
+- New `weight_shape` constructor argument: `QuantLinear.__init__` passes `(out_features, in_features)` so the Quantizer can derive the scale shape without touching the weight data.
+- New `weight_scale_init_value` / `activation_scale_init_value` options, used by `_init_quant_params` / `_init_lwc_params` whenever `_needs_external_weight_init(x)` is true.
+- The final `F.linear` in `QuantLinear.forward` is now wrapped in `torch.amp.autocast(device_type="cuda", enabled=False)` and aligns `input.dtype` to `weight.dtype` beforehand, so DeepSpeed's `zero3_linear_wrap` autocast cannot silently turn a bf16 input back into fp16.
+- `fake_quant` casts `out` back to `x.dtype` at the end, preventing `bf16 * fp32 = fp32` dtype leakage.
+
+### Optimizer deduplication
+
+`End2EndTrainer._init_optimizer` collects trainable scale / LWC parameters deduplicated by `id()` via `_unique_named_params(...)`, so the same Parameter is never added to a param group twice (which happens when MoE experts share tensors). Combined with `patch_deepspeed_duplicate_check()`, this passes DeepSpeed's safety checks.
+
+### Memory control in `convert` + `save`
+
+On the ZeRO-3 path, `QAT.convert`:
+
+1. For every `QuantLinear`: **all ranks** enter `gathered_param_if_zero3` to materialise the full weight (keeping the NCCL collective symmetric), but **only rank0** keeps a CPU clone; see the sketch after this list.
+2. rank0 runs the fp8/int quantization once inside a temporary `QDQModule`, extracts `weight / weight_scale / input_scale / bias` into `self._rank0_state_dict`, then discards the temporary module.
+3. **The model structure is not modified**: keeping `QuantLinear` means the second scan (collecting non-QuantLinear parameters such as embed, lm_head, layernorm, and the MoE router gate) sees the same `named_parameters` order on every rank, so the gather collectives cannot deadlock.
+4. `QAT.save` hands `_rank0_state_dict` to `save_via_model_save_func`, which patches `hf_model.state_dict` and calls the original `save_func.save(...)` on **rank0 only**.
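+
+The symmetric gather-then-clone pattern from step 1, in isolation (a sketch; `quant_linear` stands for any `QuantLinear` instance, not the exact `convert()` code):
+
+```python
+import torch
+
+from angelslim.utils.zero3_io import gathered_param_if_zero3
+
+# Every rank must enter the gather (it is an NCCL collective), but only
+# rank 0 materialises a CPU copy of the full weight.
+with gathered_param_if_zero3(quant_linear.weight):
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        full_weight = quant_linear.weight.detach().cpu().clone()
+```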
+
+During convert, the CPU peak on ranks > 0 is roughly one layer's full weight (tens of MB up to GB); on rank0 it is the accumulated consolidated state_dict (the full model size).
+
+## Verified scenarios
+
+| Scenario | Model | Hardware | Result |
+|------|------|------|------|
+| Dense ZeRO-3 QAT | Qwen3-4B | 2×H20 | ✓ PTQ→QAT→save works end to end; the artifact loads in transformers |
+| MoE ZeRO-3 QAT | Qwen3-30B-A3B (48 layers × 128 experts) | 8×H20 | ✓ `stream_load_scales` hits 37248 scales, training loss stable, 31 GB fp8 checkpoint written |
+| KD + LM loss mix | same | same | ✓ `lm_loss / kd/cakld / kd/forward_kl / kd/backward_kl / total_loss` logged per their weights |
+| Last-assistant-only supervision | TextDataset (jsonl) | - | ✓ first valid label index falls after `<\|im_start\|>assistant\n` |
+| Non-ZeRO-3 regression | Qwen3-4B single-GPU PTQ | 1×H20 | ✓ behaviour identical to main, no regression |
+
diff --git a/scripts/qat/run_qat_for_hunyuanv3_a20b_zero3.sh b/scripts/qat/run_qat_for_hunyuanv3_a20b_zero3.sh
new file mode 100755
index 00000000..e5b2d1c3
--- /dev/null
+++ b/scripts/qat/run_qat_for_hunyuanv3_a20b_zero3.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+export PYTORCH_ALLOC_CONF="expandable_segments:True"
+
+NPROC=${NPROC:-8}
+CONFIG=${CONFIG:-configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml}
+
+torchrun --nproc_per_node=${NPROC} \
+  tools/run.py \
+  -c "${CONFIG}"
diff --git a/scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh b/scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh
new file mode 100755
index 00000000..1b244127
--- /dev/null
+++ b/scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+export PYTORCH_ALLOC_CONF="expandable_segments:True"
+
+NPROC=${NPROC:-8}
+CONFIG=${CONFIG:-configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml}
+
+torchrun --nproc_per_node=${NPROC} \
+  tools/run.py \
+  -c "${CONFIG}"
diff --git a/scripts/qat/run_qat_for_qwen_4b_zero3.sh b/scripts/qat/run_qat_for_qwen_4b_zero3.sh
new file mode 100755
index 00000000..45181701
--- /dev/null
+++ b/scripts/qat/run_qat_for_qwen_4b_zero3.sh
@@ -0,0 +1,14 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Smoke-test launcher: Qwen3-4B with DeepSpeed ZeRO-3. Matches the full
+# 30B-A3B flow but on a model small enough to iterate quickly.
+
+export PYTORCH_ALLOC_CONF="expandable_segments:True"
+
+NPROC=${NPROC:-2}
+CONFIG=${CONFIG:-configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml}
+
+torchrun --nproc_per_node=${NPROC} \
+  tools/run.py \
+  -c "${CONFIG}"
diff --git a/tools/run.py b/tools/run.py
index c4135199..88d84451 100644
--- a/tools/run.py
+++ b/tools/run.py
@@ -112,6 +112,7 @@ def multi_nodes_run(config):
         shuffle=dataset_config.shuffle,
         inference_settings=dataset_config.inference_settings,
         use_audio_in_video=model_config.use_audio_in_video,
+        is_sft_data=dataset_config.is_sft_data,
     )
 
     # Step 6: Initialize compressor
@@ -267,6 +268,32 @@ def weight_only_run(config):
     )
 
 
+def _prewarm_hf_deepspeed_config(config):
+    """Pre-construct ``Seq2SeqTrainingArguments`` so HF's
+    ``HfTrainerDeepSpeedConfig`` weak-ref is registered BEFORE
+    ``from_pretrained`` runs. That is what flips
+    ``is_deepspeed_zero3_enabled()`` to True and makes our
+    ``BaseLLMModel.from_pretrained`` take the ZeRO-3 path.
+
+    Returns the constructed TrainingArguments (kept alive via the caller's
+    local variable) or None if not applicable.
+ """ + compress_cfg = getattr(config, "compression_config", None) + qat_cfg = getattr(compress_cfg, "QAT", None) if compress_cfg is not None else None + hf_args = getattr(qat_cfg, "hf_args", None) if qat_cfg is not None else None + if not hf_args or not hf_args.get("deepspeed"): + return None + + from transformers import Seq2SeqTrainingArguments + + trainer_args = Seq2SeqTrainingArguments( + output_dir=config.global_config.save_path, + **hf_args, + ) + print_info("[DeepSpeed pre-warm] HfTrainerDeepSpeedConfig registered before model load.") + return trainer_args + + def run(config): """ Run the LLM compression process based on the provided configuration. @@ -295,6 +322,12 @@ def run(config): weight_only_run(config) return + # QAT + DeepSpeed: register HfTrainerDeepSpeedConfig BEFORE loading the + # model so ``from_pretrained`` takes the ZeRO-3 path. No-op otherwise. + # The returned object must stay alive until after the model is built + # because HF's weak-ref mechanism drops the config otherwise. + _hf_ds_args = _prewarm_hf_deepspeed_config(config) + # Step 2: Execute complete pipeline slim_engine = Engine() @@ -312,6 +345,9 @@ def run(config): attn_implementation=model_config.attn_implementation, deploy_backend=global_config.deploy_backend, ) + # Safe to release now: the model is built and any deepspeed.zero.Init + # effects have already happened on all parameters. + del _hf_ds_args # Step 4: Prepare data (optional custom dataloader) if compress_config.need_dataset: @@ -327,6 +363,7 @@ def run(config): use_audio_in_video=model_config.use_audio_in_video, model_name=model_config.name, quantization_config=compress_config.quantization, + is_sft_data=dataset_config.is_sft_data, ) # Step 5: Initialize compressor