diff --git a/.gitignore b/.gitignore
index 00e37b2b..d8fcb3d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,8 @@ eval/
*_ckpt*/
output/
outputs/
+output*/
+logs*/
outs/
wandb/
tools/results/
diff --git a/angelslim/compressor/qat/modules/quantizer.py b/angelslim/compressor/qat/modules/quantizer.py
index bc81e998..4eb627e4 100644
--- a/angelslim/compressor/qat/modules/quantizer.py
+++ b/angelslim/compressor/qat/modules/quantizer.py
@@ -18,6 +18,8 @@
import torch.nn as nn
import torch.nn.functional as F
+from ....utils import is_deepspeed_zero3_enabled, is_zero3_param
+
FP8_E4M3_QMIN = -448
FP8_E4M3_QMAX = 448
@@ -49,10 +51,30 @@ def _parse_bits_and_dtype(qtype_str):
class Quantizer(nn.Module):
- def __init__(self, config, quant_info, x=None, is_act=False, resume=False, num_heads=-1):
+ def __init__(
+ self,
+ config,
+ quant_info,
+ x=None,
+ is_act=False,
+ resume=False,
+ num_heads=-1,
+ weight_shape=None,
+ ):
super().__init__()
self.is_act = is_act
self.num_heads = num_heads
+ # ``weight_shape`` lets the caller pre-declare the (out_features,
+ # in_features) of the parent Linear so we can size weight-side
+ # quantizer Parameters without ever touching the (possibly ZeRO-3
+ # sharded) weight tensor.
+ self.weight_shape = (
+ (int(weight_shape[0]), int(weight_shape[1])) if weight_shape is not None else None
+ )
+ # Configurable initial values used when ZeRO-3 is active and we
+ # cannot depend on the weight data.
+ self.weight_scale_init_value = float(config.get("weight_scale_init_value", 1.0))
+ self.activation_scale_init_value = float(config.get("activation_scale_init_value", 1.0))
info = quant_info.quant_algo_info["w"]
self.group_size = quant_info.quant_algo_info.get("w_group_size", -1)
rewrite_conf = config.get("weight", {})
@@ -117,8 +139,21 @@ def _init_quant_params(self, x):
self.scale = self.zero_point = None
if self.resume:
self.init = True
- zp = torch.empty(1) if not self.is_sym else None
- self._set_quant_parameters(torch.empty(1), zp)
+ init_val = self.activation_scale_init_value
+ scale = torch.full((1,), init_val, dtype=torch.float32)
+ zp = torch.zeros(1, dtype=torch.float32) if not self.is_sym else None
+ self._set_quant_parameters(scale, zp)
+ return
+
+ # Weight-side path. If we cannot use ``x`` (ZeRO-3 sharded,
+ # meta, or simply not provided), allocate Parameters by shape
+ # and ``weight_scale_init_value``.
+ if self._needs_external_weight_init(x):
+ shape = self._weight_scale_shape_from_meta()
+ init_val = self.weight_scale_init_value
+ scale = torch.full(shape, init_val, dtype=torch.float32)
+ zp = torch.zeros(shape, dtype=torch.float32) if not self.is_sym else None
+ self._set_quant_parameters(scale, zp)
return
if self.is_sym:
@@ -131,6 +166,52 @@ def _init_quant_params(self, x):
)
self._set_quant_parameters(scale, zp.round())
+ def _needs_external_weight_init(self, x):
+ """True when weight-side init must skip data-dependent computation
+ and instead allocate Parameters from shape + init_value.
+
+ Triggered by:
+ * DeepSpeed ZeRO-3 active (HF integration registered)
+ * ``x`` is a ZeRO-3 sharded Parameter
+ * ``x`` is None / on meta device / empty
+ """
+ if is_deepspeed_zero3_enabled():
+ return True
+ if x is None:
+ return True
+ if is_zero3_param(x):
+ return True
+ if hasattr(x, "device") and x.device.type == "meta":
+ return True
+ if hasattr(x, "numel") and x.numel() == 0:
+ return True
+ return False
+
+ def _weight_2d_shape(self):
+ """Resolve (out_features, in_features) for the underlying Linear.
+ Callers must have passed ``weight_shape`` via ``QuantLinear``."""
+ if self.weight_shape is not None:
+ return self.weight_shape
+ raise RuntimeError(
+ "Quantizer needs ``weight_shape`` to size weight scale without a "
+ "concrete tensor (set in QuantLinear.__init__)."
+ )
+
+ def _weight_scale_shape_from_meta(self):
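+        """Return the weight-scale shape implied by granularity alone.
+
+        Illustrative example (sizes assumed for exposition): a 4096x11008
+        Linear yields ``(4096, 1)`` per-channel, ``(4096, 86)`` per-group
+        with ``group_size=128``, and ``(1,)`` per-tensor.
+        """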
+ out_dim, in_dim = self._weight_2d_shape()
+ if self.granularity == "per-channel":
+ return (out_dim, 1)
+ if self.granularity == "per-group":
+ if not self.group_size or self.group_size <= 0:
+ raise ValueError("per-group quantization requires positive group_size.")
+ if in_dim % self.group_size != 0:
+ raise ValueError(
+ f"dim 1 ({in_dim}) not divisible by group_size ({self.group_size})"
+ )
+ return (out_dim, in_dim // self.group_size)
+ # per-tensor and any reduce-to-scalar variant
+ return (1,)
+
def _init_lwc_params(self, x, config):
lwc_cfg = config.get("lwc", {})
if isinstance(lwc_cfg, dict):
@@ -141,11 +222,18 @@ def _init_lwc_params(self, x, config):
self.lwc_init_value = 4.0
if self.lwc:
- if x.dim() != 2:
- x_for_shape = x.flatten(1)
+ # Resolve (out_dim, in_dim) without depending on ``x`` data.
+ if self._needs_external_weight_init(x):
+ out_dim, in_dim = self._weight_2d_shape()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
- x_for_shape = x
- out_dim, in_dim = x_for_shape.shape
+ if x.dim() != 2:
+ x_for_shape = x.flatten(1)
+ else:
+ x_for_shape = x
+ out_dim, in_dim = x_for_shape.shape
+ device = x.device
+
if self.granularity == "per-group":
if not self.group_size or self.group_size <= 0:
raise ValueError("per-group quantization requires positive group_size.")
@@ -157,9 +245,7 @@ def _init_lwc_params(self, x, config):
else:
dim1 = 1
- init = (
- torch.ones((dim1, 1), device=x.device, dtype=torch.float32) * self.lwc_init_value
- )
+ init = torch.ones((dim1, 1), device=device, dtype=torch.float32) * self.lwc_init_value
self.clip_factor_w_max = nn.Parameter(init.clone(), requires_grad=True)
self.clip_factor_w_min = nn.Parameter(init.clone(), requires_grad=True)
self.sigmoid = nn.Sigmoid()
@@ -473,7 +559,14 @@ def fake_quant(self, x):
None if self.is_sym else clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax)
)
scale, round_zero_point = self._expand_scale_zp(scale, round_zero_point, x)
- return self._fake_quant_with_params(x, scale, round_zero_point)
+ out = self._fake_quant_with_params(x, scale, round_zero_point)
+ # Scale is kept in fp32 for numerical stability, but multiplying by
+ # a bf16/fp16 activation upcasts the result. Cast back to the input
+ # dtype so downstream F.linear / DeepSpeed autocast wrappers see a
+ # consistent dtype.
+ if out.dtype != x.dtype:
+ out = out.to(x.dtype)
+ return out
def forward(self, x: torch.Tensor):
if self.bits >= 16:
@@ -516,8 +609,18 @@ def __init__(
self.register_parameter("bias", org_module.bias)
self.use_weight_quant = use_weight_quant
self.use_act_quant = use_act_quant
+ # Under ZeRO-3 the weight Parameter ``org_module.weight`` may be a
+ # zero-numel shard. Pass an explicit (out, in) shape so the weight
+ # quantizer can size its scale Parameter from the Linear shape
+ # rather than inspecting the (possibly sharded) tensor.
+ weight_shape = (org_module.out_features, org_module.in_features)
if self.use_weight_quant:
- self.weight_quantizer = Quantizer(config, quant_info, x=org_module.weight)
+ self.weight_quantizer = Quantizer(
+ config,
+ quant_info,
+ x=org_module.weight,
+ weight_shape=weight_shape,
+ )
if self.use_act_quant:
self.act_quantizer = Quantizer(config, quant_info, is_act=True, resume=resume)
@@ -531,13 +634,16 @@ def __init__(
)
def forward(self, input: torch.Tensor):
- if input.shape[0] == 0:
- return self.fwd_func(input, self.weight, self.bias)
-
weight = self.weight_quantizer(self.weight) if self.use_weight_quant else self.weight
if self.use_act_quant:
input = self.act_quantizer(input)
- output = self.fwd_func(input, weight, self.bias)
+ # Defensive dtype alignment: upstream (DeepSpeed ZeRO-3 / HF
+ # autocast) may have cast ``input`` to fp16 even though we run in
+ # bf16. Align to the (fake-quantised) weight dtype so F.linear
+ # stays consistent.
+ output = self.fwd_func(
+ input.to(self.weight.dtype), weight.to(self.weight.dtype), self.bias
+ )
if self.use_qkv_quant:
output = self.qkv_quantizer(output)
return output
diff --git a/angelslim/compressor/qat/plugins/distill_loss.py b/angelslim/compressor/qat/plugins/distill_loss.py
new file mode 100644
index 00000000..4024b3e4
--- /dev/null
+++ b/angelslim/compressor/qat/plugins/distill_loss.py
@@ -0,0 +1,117 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn.functional as F
+
+
+class DistillLoss:
+ def __init__(self, loss_type="kl", loss_topk=None, kd_temperature=1.0):
+ self.loss_type = str(loss_type).lower()
+ self.loss_topk = loss_topk
+ self.kd_temperature = float(kd_temperature)
+
+ @staticmethod
+ def _kl_per_token(log_p_src: torch.Tensor, p_tgt: torch.Tensor) -> torch.Tensor:
+ """Per-token KL(tgt || src) (shape [N])."""
+ return F.kl_div(log_p_src, p_tgt, reduction="none").sum(dim=-1)
+
+ def compute(self, student_logits, teacher_logits, labels):
+ """Return per-token KD losses computed only on valid labels.
+
+ The returned dict always contains ``loss``, ``forward_kl`` and
+ ``backward_kl`` so callers can log diagnostics for every KD variant.
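+
+        Illustrative usage (tensor sizes below are made up, not taken from a
+        real model)::
+
+            loss_fn = DistillLoss(loss_type="kl")
+            student = torch.randn(2, 8, 128)          # [batch, seq, vocab]
+            teacher = torch.randn(2, 8, 128)
+            labels = torch.randint(0, 128, (2, 8))
+            out = loss_fn.compute(student, teacher, labels)
+            # out["loss"], out["forward_kl"], out["backward_kl"] are scalars.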
+ """
+ flat_mask = (labels != -100).reshape(-1)
+ if flat_mask.sum() == 0:
+ zero = student_logits.new_zeros(())
+ return {"loss": zero, "forward_kl": zero, "backward_kl": zero}
+
+ # Flatten to [N, V] and keep only valid tokens.
+ s_flat = student_logits.flatten(0, -2)[flat_mask]
+ t_flat = teacher_logits.flatten(0, -2)[flat_mask]
+ valid_labels = labels.reshape(-1)[flat_mask]
+
+ # Diagnostic KL (always computed at T=1 on valid tokens).
+ s_logp = F.log_softmax(s_flat, dim=-1)
+ t_logp = F.log_softmax(t_flat, dim=-1)
+ s_p = s_logp.exp()
+ t_p = t_logp.exp()
+ forward_kl = self._kl_per_token(s_logp, t_p).mean()
+ backward_kl = self._kl_per_token(t_logp, s_p).mean()
+
+ # Main KD loss according to loss_type.
+ if self.loss_type == "kl":
+ kd = forward_kl
+ elif self.loss_type == "rkl":
+ kd = backward_kl
+ elif self.loss_type == "mse":
+ kd = F.mse_loss(s_flat, t_flat)
+ elif self.loss_type == "kd":
+ # Legacy "kd": temperature-scaled forward KL. The outer trainer
+ # combines this value with the LM loss by configured weights.
+ temperature = max(self.kd_temperature, 1e-6)
+ kd = self._kl_per_token(
+ F.log_softmax(s_flat / temperature, dim=-1),
+ F.softmax(t_flat / temperature, dim=-1),
+ ).mean() * (temperature * temperature)
+ elif self.loss_type == "cakld":
+ # Contextual Asymmetric KL-divergence: per-token mixing of
+ # forward / reverse KL by teacher's confidence on the label.
+ per_tok_fkl = self._kl_per_token(s_logp, t_p)
+ per_tok_bkl = self._kl_per_token(t_logp, s_p)
+ conf = torch.gather(t_p, dim=-1, index=valid_labels.unsqueeze(-1)).squeeze(-1)
+ kd = (conf * per_tok_bkl + (1.0 - conf) * per_tok_fkl).mean()
+ elif self._is_reverse_topk_loss(self.loss_type):
+ topk = self._resolve_topk(self.loss_type, s_flat.size(-1))
+ top_s, idx = s_flat.topk(topk, dim=-1, sorted=False)
+ top_t = t_flat.gather(-1, idx)
+ kd = self._kl_per_token(
+ F.log_softmax(top_t, dim=-1),
+ F.softmax(top_s, dim=-1),
+ ).mean()
+ elif self._is_forward_topk_loss(self.loss_type):
+ topk = self._resolve_topk(self.loss_type, t_flat.size(-1))
+ top_t, idx = t_flat.topk(topk, dim=-1, sorted=False)
+ top_s = s_flat.gather(-1, idx)
+ kd = self._kl_per_token(
+ F.log_softmax(top_s, dim=-1),
+ F.softmax(top_t, dim=-1),
+ ).mean()
+ else:
+ raise ValueError(
+ f"Unsupported QAT kd loss_type: {self.loss_type}. "
+ "Valid: kl, rkl, mse, kd, cakld, kl_top[_K], r_kl_top[_K]."
+ )
+
+ return {"loss": kd, "forward_kl": forward_kl, "backward_kl": backward_kl}
+
+ @staticmethod
+ def _is_forward_topk_loss(loss_type):
+ return loss_type.startswith("kl_top")
+
+ @staticmethod
+ def _is_reverse_topk_loss(loss_type):
+ return loss_type.startswith("r_kl_top") or loss_type.startswith("rkl_top")
+
+ def _resolve_topk(self, loss_type, vocab_size):
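+        """Resolve the top-k size for the ``*_top`` loss variants.
+
+        For example, ``loss_type="kl_top_512"`` with ``loss_topk=None``
+        resolves to 512; if neither source specifies a value it defaults to
+        1000, and the result is always capped at ``vocab_size``.
+        """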
+ topk = self.loss_topk
+ if topk is None and "_top_" in loss_type:
+ topk = int(loss_type.rsplit("_", 1)[-1])
+ if topk is None:
+ topk = 1000
+ topk = int(topk)
+ if topk <= 0:
+ raise ValueError(f"loss_topk must be positive, got: {topk}")
+ return min(topk, vocab_size)
diff --git a/angelslim/compressor/qat/plugins/learnable_scale.py b/angelslim/compressor/qat/plugins/learnable_scale.py
index 8f01514f..0f9ef3cc 100644
--- a/angelslim/compressor/qat/plugins/learnable_scale.py
+++ b/angelslim/compressor/qat/plugins/learnable_scale.py
@@ -15,7 +15,13 @@
import torch
from tqdm import tqdm
-from ....utils import print_info, set_op_by_name
+from ....utils import (
+ gathered_params_if_zero3,
+ is_deepspeed_zero3_enabled,
+ print_info,
+ set_op_by_name,
+ stream_load_scales,
+)
from ..modules.quantizer import QuantLinear
from .base_plugin import BasePlugin
from .plugin_manager import PluginManager
@@ -29,11 +35,22 @@
@PluginManager.plugin("learnable_scale")
class LearnableScalePlugin(BasePlugin):
- def __init__(self, quant_info=None, ignore_layers=None, resume_ckpt_dir=None, **kwargs):
+ def __init__(
+ self,
+ quant_info=None,
+ ignore_layers=None,
+ resume_ckpt_dir=None,
+ from_ptq_ckpt_dir=None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.quant_info = quant_info
self.ignore_layers = ignore_layers
self.resume_ckpt_dir = resume_ckpt_dir
+ # Optional warm-start from a PTQ "real" checkpoint (only scales are
+ # read; base weights stay as loaded by from_pretrained). Required
+ # under DeepSpeed ZeRO-3.
+ self.from_ptq_ckpt_dir = from_ptq_ckpt_dir
self.use_weight_quant = self.config.get("use_weight_quant", False)
self.use_activation_quant = self.config.get("use_activation_quant", False)
self.fp8_attn = self.config.get("fp8_attn", False)
@@ -47,9 +64,23 @@ def __init__(self, quant_info=None, ignore_layers=None, resume_ckpt_dir=None, **
self.learn_norm = learnable_cfg.get("norm", False)
def before_train(self, **kwargs):
+ zero3 = is_deepspeed_zero3_enabled()
+ if zero3 and not self.from_ptq_ckpt_dir:
+ raise ValueError(
+ "DeepSpeed ZeRO-3 QAT requires `compression.QAT.from_ptq_ckpt` "
+ "to warm-start scales (lazy_init via forward is impossible "
+ "on sharded weights)."
+ )
+
# Retrieve KV head count from model config for per-head quantization
model_config = getattr(self.quant_model.model, "config", None)
num_kv_heads = getattr(model_config, "num_key_value_heads", -1)
+ # Pre-allocate ``act_quantizer.scale`` as a Parameter whenever we
+ # plan to fill it from a checkpoint (full resume OR PTQ warm-start
+ # OR ZeRO-3 — where lazy_init is impossible).
+ act_preallocate = (
+ self.resume_ckpt_dir is not None or self.from_ptq_ckpt_dir is not None or zero3
+ )
for name, module in self.quant_model.model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(ig in name for ig in self.ignore_layers):
@@ -67,7 +98,7 @@ def before_train(self, **kwargs):
self.quant_info,
self.use_weight_quant,
self.use_activation_quant,
- resume=self.resume_ckpt_dir is not None,
+ resume=act_preallocate,
qkv_config=qkv_cfg,
)
set_op_by_name(self.quant_model.model, name, q_linear)
@@ -78,10 +109,18 @@ def before_train(self, **kwargs):
print_info(self.quant_model.model)
+ # Warm-start scales from a previous PTQ "real" checkpoint. Only
+ # quantizer Parameters are touched; base Linear weights are NOT
+ # overwritten.
+ if self.from_ptq_ckpt_dir is not None:
+ stream_load_scales(self.quant_model.model, self.from_ptq_ckpt_dir)
+
if (
self.use_activation_quant
and not q_linear.act_quantizer.dynamic
and self.resume_ckpt_dir is None
+ and not zero3
+ and self.from_ptq_ckpt_dir is None
):
self._lazy_init(**kwargs)
@@ -284,5 +323,13 @@ def _get_qkv_config_for_layer(name, quant_config):
@torch.no_grad()
def quant_inplace(model):
for _, module in model.named_modules():
- if isinstance(module, QuantLinear):
+ if not isinstance(module, QuantLinear):
+ continue
+ # Gather the weight together with all weight_quantizer Parameters
+ # (scale / zero_point / optional LWC clip factors) so the
+ # fake-quant runs on the full materialised tensor under ZeRO-3.
+ params = [module.weight]
+ if hasattr(module, "weight_quantizer"):
+ params.extend(module.weight_quantizer.parameters(recurse=True))
+ with gathered_params_if_zero3(params, modifier_rank=None):
module.weight.data = module.weight_quantizer(module.weight.data)
diff --git a/angelslim/compressor/qat/qat.py b/angelslim/compressor/qat/qat.py
index a4274595..c5c524cd 100644
--- a/angelslim/compressor/qat/qat.py
+++ b/angelslim/compressor/qat/qat.py
@@ -17,7 +17,13 @@
import torch
from safetensors.torch import save_file
-from ...utils import print_info, set_op_by_name
+from ...utils import (
+ gathered_param_if_zero3,
+ model_has_zero3_params,
+ print_info,
+ save_via_model_save_func,
+ set_op_by_name,
+)
from ..compressor_factory import CompressorFactory
from ..quant.modules.helper_layer import QDQModule
from .modules.quantizer import QuantLinear
@@ -38,6 +44,11 @@ def __init__(self, model, slim_config=None):
self.quant_model.init_ptq(slim_config)
self.quant_info = self.quant_model.quant_config
self.plugin_manager = PluginManager()
+ # When set, ``save`` will use this rank-0-only state_dict instead of
+ # walking the model again. Populated by ``convert`` under ZeRO-3 to
+ # avoid keeping a full CPU copy of every layer's QDQModule on every
+ # rank.
+ self._rank0_state_dict = None
self._init_plugins()
self._init_trainer()
@@ -57,6 +68,7 @@ def _init_plugins(self):
quant_info=self.quant_info,
ignore_layers=self.config["compress_config"].quantization.ignore_layers,
resume_ckpt_dir=self.config["compress_config"].QAT.resume_ckpt_dir,
+ from_ptq_ckpt_dir=self.config["compress_config"].QAT.from_ptq_ckpt,
config=self.plugin_config.get("quant_config", {}),
quant_model=self.quant_model,
)
@@ -72,46 +84,169 @@ def _init_trainer(self):
def run(self, dataloader):
self.trainer.run(dataloader)
+ @staticmethod
+ def _gather_clone(tensor):
+ """Detach + CPU-clone a tensor, gathering if it is a ZeRO-3 shard.
+
+ WARNING: every rank gets a full CPU copy. Only safe for SMALL tensors
+ (e.g. scale Parameters). For large weights under ZeRO-3 use
+ ``_rank0_gather_clone`` instead.
+ """
+ if tensor is None:
+ return None
+ with gathered_param_if_zero3(tensor):
+ return tensor.detach().cpu().clone()
+
+ @staticmethod
+ def _sym_gather_clone(tensor):
+ """Symmetric gather-and-clone: every rank gets a full CPU copy.
+
+ Collective timing is symmetric across ranks (minimising NCCL
+ stalls). Caller is responsible for dropping the clone on rank>0
+ immediately to avoid keeping 'world_size' copies alive.
+ """
+ if tensor is None:
+ return None
+ with gathered_param_if_zero3(tensor):
+ return tensor.detach().cpu().clone()
+
def convert(self):
if self.save_fmt not in ("real", "real_and_kvcache"):
return
- print_info("Start QAT convert: replacing QuantLinear with QDQModule...")
+ zero3 = model_has_zero3_params(self.quant_model.model)
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
quant_algo = self.quant_info.quant_algo
+ if not zero3:
+ # ----- single-GPU / non-ZeRO-3 path: original behaviour -----
+ print_info("Start QAT convert: replacing QuantLinear with QDQModule...")
+ for name, module in [
+ (n, m)
+ for n, m in self.quant_model.model.named_modules()
+ if isinstance(m, QuantLinear)
+ ]:
+ weight_scale = (
+ module.weight_quantizer.scale.data.clone()
+ if hasattr(module, "weight_quantizer")
+ else None
+ )
+ input_scale = None
+ if module.use_act_quant and hasattr(module, "act_quantizer"):
+ aq = module.act_quantizer
+ if getattr(aq, "scale", None) is not None:
+ input_scale = aq.scale.data.clone()
+ qdq_module = QDQModule(
+ quant_algo=quant_algo,
+ weight=module.weight,
+ weight_scale=weight_scale,
+ bias=module.bias,
+ group_size=(
+ module.weight_quantizer.group_size
+ if hasattr(module, "weight_quantizer")
+ and hasattr(module.weight_quantizer, "group_size")
+ else 128
+ ),
+ input_scale=input_scale,
+ )
+ set_op_by_name(self.quant_model.model, name, qdq_module)
+ return
+
+ # ----- ZeRO-3 path: every rank gathers + clones per layer (fast,
+ # NCCL-symmetric), but only rank0 keeps the data by feeding it into
+ # ``_rank0_state_dict``. rank>0 drops the clone immediately so peak
+ # CPU remains bounded by ~one layer's worth of tensors per rank.
+ # Model structure is NOT modified — we stream straight into the
+ # state_dict and let ``save_via_model_save_func`` patch
+ # ``state_dict()`` for the underlying save_func.
+ print_info(
+ f"[rank{rank}] Start QAT convert (ZeRO-3 mode: stream rank0 "
+ "state_dict, keep model structure intact)..."
+ )
+        self._rank0_state_dict = {}
+
quant_linear_modules = [
- (name, module)
- for name, module in self.quant_model.model.named_modules()
- if isinstance(module, QuantLinear)
+ (n, m) for n, m in self.quant_model.model.named_modules() if isinstance(m, QuantLinear)
]
+ consumed_prefixes = set()
for name, module in quant_linear_modules:
+
+ # Symmetric gather: all ranks clone (memcpy, fast) so NCCL
+ # timing stays tight. This is ~world_size× transient CPU RAM
+ # for JUST this one layer; we free right after the rank0
+ # branch completes.
+ weight = self._sym_gather_clone(module.weight)
+ bias = self._sym_gather_clone(getattr(module, "bias", None))
weight_scale = None
if hasattr(module, "weight_quantizer"):
- weight_scale = module.weight_quantizer.scale.data.clone()
-
+ weight_scale = self._sym_gather_clone(module.weight_quantizer.scale)
input_scale = None
if module.use_act_quant and hasattr(module, "act_quantizer"):
- act_quantizer = module.act_quantizer
- if hasattr(act_quantizer, "scale") and act_quantizer.scale is not None:
- input_scale = act_quantizer.scale.data.clone()
+ aq = module.act_quantizer
+ if getattr(aq, "scale", None) is not None:
+ input_scale = self._sym_gather_clone(aq.scale)
+ consumed_prefixes.add(name)
+
+ if rank != 0:
+ # Drop the clone immediately; next iteration will overwrite
+ # these locals anyway but be explicit for clarity.
+ del weight, bias, weight_scale, input_scale
+ continue
+
+ # rank0 only: run the fp8/int quantize path via a throwaway
+ # QDQModule, then move its params into the consolidated dict
+ # and discard the module.
qdq_module = QDQModule(
quant_algo=quant_algo,
- weight=module.weight,
+ weight=weight,
weight_scale=weight_scale,
- bias=module.bias,
+ bias=bias,
group_size=(
module.weight_quantizer.group_size
- if hasattr(module.weight_quantizer, "group_size")
+ if hasattr(module, "weight_quantizer")
+ and hasattr(module.weight_quantizer, "group_size")
else 128
),
input_scale=input_scale,
)
- set_op_by_name(self.quant_model.model, name, qdq_module)
+ for sub_name, p in qdq_module.named_parameters(recurse=False):
+ self._rank0_state_dict[f"{name}.{sub_name}"] = p.detach().cpu()
+ for sub_name, b in qdq_module.named_buffers(recurse=False):
+ if sub_name in qdq_module._non_persistent_buffers_set:
+ continue
+ self._rank0_state_dict[f"{name}.{sub_name}"] = b.detach().cpu()
+ del qdq_module, weight, bias, weight_scale, input_scale
+
+ # Second pass: params/buffers that are NOT inside a QuantLinear
+ # (embeddings, lm_head, layernorms, MoE router gate, ...). The
+ # collective order MUST be identical across ranks, so this loop
+ # runs on every rank; only rank0 keeps the data.
+ for pname, param in self.quant_model.model.named_parameters():
+ if any(pname.startswith(p + ".") for p in consumed_prefixes):
+ continue
+ with gathered_param_if_zero3(param):
+ if rank == 0:
+ self._rank0_state_dict[pname] = param.detach().cpu().clone()
+
+ if rank == 0:
+ for module_name, mod in self.quant_model.model.named_modules():
+ for buf_name, buf in mod.named_buffers(recurse=False):
+ if buf is None or buf_name in mod._non_persistent_buffers_set:
+ continue
+ full_key = f"{module_name}.{buf_name}" if module_name else buf_name
+ if full_key in self._rank0_state_dict:
+ continue
+ self._rank0_state_dict[full_key] = buf.detach().cpu().clone()
+ print_info(
+ f"[zero3] convert done: rank0 state_dict has "
+ f"{len(self._rank0_state_dict)} tensors."
+ )
def _save_kv_cache_scales(self, save_path: str):
"""Extract and save KV cache scales to a safetensors file."""
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
kv_scales = {}
for name, module in self.quant_model.model.named_modules():
if not isinstance(module, QuantLinear):
@@ -126,8 +261,12 @@ def _save_kv_cache_scales(self, save_path: str):
else:
continue
scale_key = f"{cache_name}.scale"
- kv_scales[scale_key] = module.qkv_quantizer.scale.data.clone().float().cpu()
+ scale_tensor = self._gather_clone(module.qkv_quantizer.scale)
+ if scale_tensor is not None and rank == 0:
+ kv_scales[scale_key] = scale_tensor.float()
+ if rank != 0:
+ return
os.makedirs(save_path, exist_ok=True)
out_file = os.path.join(save_path, "kv_cache_scales.safetensors")
save_file(kv_scales, out_file)
@@ -146,7 +285,12 @@ def save(self, save_path: str):
# "real": save real-quant model via model-specific save function
elif self.save_fmt == "real":
save_func = self.quant_model.get_save_func()(self.quant_model)
- save_func.save(os.path.join(save_path, "final_quant_checkpoint"))
+ save_via_model_save_func(
+ self.quant_model,
+ save_func,
+ os.path.join(save_path, "final_quant_checkpoint"),
+ prebuilt_state_dict=self._rank0_state_dict,
+ )
# "save_kvcache_only": only export KV cache scales (kv_cache_scales.safetensors)
elif self.save_fmt == "save_kvcache_only":
@@ -155,7 +299,12 @@ def save(self, save_path: str):
# "real_and_kvcache": save real-quant model AND KV cache scales
elif self.save_fmt == "real_and_kvcache":
save_func = self.quant_model.get_save_func()(self.quant_model)
- save_func.save(os.path.join(save_path, "final_quant_checkpoint"))
+ save_via_model_save_func(
+ self.quant_model,
+ save_func,
+ os.path.join(save_path, "final_quant_checkpoint"),
+ prebuilt_state_dict=self._rank0_state_dict,
+ )
self._save_kv_cache_scales(os.path.join(save_path, "final_quant_checkpoint"))
else:
diff --git a/angelslim/compressor/qat/trainers/end2end_trainer.py b/angelslim/compressor/qat/trainers/end2end_trainer.py
index fb5daf6d..57eac608 100644
--- a/angelslim/compressor/qat/trainers/end2end_trainer.py
+++ b/angelslim/compressor/qat/trainers/end2end_trainer.py
@@ -13,16 +13,33 @@
# limitations under the License.
import torch
-import torch.nn.functional as F
from datasets import load_dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from ....data.qat_dataset import QATDataset
-from ....utils import print_info
+from ....utils import patch_deepspeed_duplicate_check, print_info
+from ..plugins.distill_loss import DistillLoss
from ..plugins.learnable_scale import set_quant_state
from .trainer_factory import TrainerFactory
+def _unique_named_params(model, predicate):
+ """Collect parameters matching ``predicate`` with id-based de-duplication.
+
+ Some QAT setups share a single scale Parameter across multiple
+ QuantLinear views (e.g. MoE experts built from a shared tensor). HF /
+ DeepSpeed optimizer init rejects duplicates, so we de-dup by ``id``.
+ """
+ seen = set()
+ result = []
+ for name, param in model.named_parameters():
+ if id(param) in seen or not predicate(name, param):
+ continue
+ seen.add(id(param))
+ result.append(param)
+ return result
+
+
class QATSeq2SeqTrainer(Seq2SeqTrainer):
def __init__(self, *args, loss_config=None, quant_config=None, **kwargs):
super().__init__(*args, **kwargs)
@@ -31,93 +48,102 @@ def __init__(self, *args, loss_config=None, quant_config=None, **kwargs):
self.loss_type = str(loss_config.get("loss_type", "origin")).lower()
self.loss_topk = loss_config.get("loss_topk")
self.kd_temperature = float(loss_config.get("kd_temperature", 1.0))
+        # ``kd_alpha`` is kept for backward compatibility but is ignored:
+        # ``lm_loss_weight`` / ``kd_loss_weight`` are now the source of truth
+        # for combining the LM and KD terms.
self.kd_alpha = float(loss_config.get("kd_alpha", 0.5))
+ self.lm_loss_weight = float(loss_config.get("lm_loss_weight", 1.0))
+ self.kd_loss_weight = float(loss_config.get("kd_loss_weight", 0.0))
+ self.distill_loss = DistillLoss(
+ loss_type=self.loss_type,
+ loss_topk=self.loss_topk,
+ kd_temperature=self.kd_temperature,
+ )
self.use_weight_quant = quant_config.get("use_weight_quant", False)
self.use_activation_quant = quant_config.get("use_activation_quant", False)
self.use_qkv_quant = quant_config.get("use_qkv_quant", False)
+ # Running metric aggregator keyed by logger mode.
+ from collections import defaultdict
+
+ self._qat_metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
+
+ def _record(self, name, value):
+ if value is None:
+ return
+ mode = "train" if self.model.training else "eval"
+        v = value.detach().float().item() if isinstance(value, torch.Tensor) else float(value)
+        self._qat_metrics[mode][name].append(v)
+
+ # ------------------------------------------------------------------
+ # compute_loss
+ # ------------------------------------------------------------------
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+ labels = inputs.get("labels", None)
+ lm_on = self.lm_loss_weight > 0.0
+ kd_on = self.kd_loss_weight > 0.0
+
+ # Back-compat: ``loss_type="origin"`` means "pure HF CE loss" → the
+ # classic SFT path with no distillation. Honour it even when the
+ # user forgot to set kd_loss_weight=0.
if self.loss_type == "origin":
- return super().compute_loss(
- model,
- inputs,
- return_outputs=return_outputs,
- num_items_in_batch=num_items_in_batch,
- )
+ kd_on = False
+
+ if not lm_on and not kd_on:
+ raise ValueError("Both lm_loss_weight and kd_loss_weight are 0 — nothing to optimise.")
- teacher_logits = self.get_ori_outputs(model, inputs).logits
+ # Student forward — always needed.
+ # HF CausalLM loss is computed when ``labels`` is present in inputs.
student_inputs = dict(inputs)
- if self.loss_type != "kd":
+ if not lm_on:
+            # ``labels`` was already captured above for the KD mask, so pop it
+            # from the student kwargs to skip HF's internal CE and save compute.
student_inputs.pop("labels", None)
outputs = model(**student_inputs)
- student_logits = outputs.logits
- if self.loss_type == "kl":
- loss = F.kl_div(
- F.log_softmax(student_logits.flatten(0, -2), dim=-1),
- F.softmax(teacher_logits.flatten(0, -2), dim=-1),
- reduction="batchmean",
- )
- elif self.loss_type == "rkl":
- loss = F.kl_div(
- F.log_softmax(teacher_logits.flatten(0, -2), dim=-1),
- F.softmax(student_logits.flatten(0, -2), dim=-1),
- reduction="batchmean",
+ lm_loss = outputs.loss if lm_on and getattr(outputs, "loss", None) is not None else None
+ if lm_on and lm_loss is None:
+ raise ValueError(
+ "lm_loss_weight > 0 but model did not return a loss — "
+ "check that ``labels`` is set in the batch."
)
- elif self.loss_type == "mse":
- loss = F.mse_loss(student_logits, teacher_logits)
- elif self.loss_type == "kd":
- if getattr(outputs, "loss", None) is None:
- raise ValueError("loss_type='kd' requires labels to compute CE loss.")
- temperature = max(self.kd_temperature, 1e-6)
- alpha = self.kd_alpha
- distill_loss = F.kl_div(
- F.log_softmax(student_logits.flatten(0, -2) / temperature, dim=-1),
- F.softmax(teacher_logits.flatten(0, -2) / temperature, dim=-1),
- reduction="batchmean",
- )
- loss = outputs.loss * (1 - alpha) + distill_loss * (alpha * temperature * temperature)
- elif self._is_reverse_topk_loss(self.loss_type):
- topk = self._resolve_topk(self.loss_type, student_logits.size(-1))
- top_student_logits, indices = student_logits.topk(topk, dim=-1, sorted=False)
- top_teacher_logits = teacher_logits.gather(-1, indices)
- loss = F.kl_div(
- F.log_softmax(top_teacher_logits.flatten(0, -2), dim=-1),
- F.softmax(top_student_logits.flatten(0, -2), dim=-1),
- reduction="batchmean",
- )
- elif self._is_forward_topk_loss(self.loss_type):
- topk = self._resolve_topk(self.loss_type, teacher_logits.size(-1))
- top_teacher_logits, indices = teacher_logits.topk(topk, dim=-1, sorted=False)
- top_student_logits = student_logits.gather(-1, indices)
- loss = F.kl_div(
- F.log_softmax(top_student_logits.flatten(0, -2), dim=-1),
- F.softmax(top_teacher_logits.flatten(0, -2), dim=-1),
- reduction="batchmean",
- )
- else:
- raise ValueError(f"Unsupported QAT loss_type: {self.loss_type}")
- return (loss, outputs) if return_outputs else loss
+ kd_info = None
+ if kd_on:
+ if labels is None:
+ raise ValueError("kd_loss_weight > 0 requires ``labels`` in the batch.")
+ teacher_logits = self.get_ori_outputs(model, inputs).logits
+ kd_info = self.distill_loss.compute(
+ outputs.logits,
+ teacher_logits,
+ labels,
+ )
- @staticmethod
- def _is_forward_topk_loss(loss_type):
- return loss_type.startswith("kl_top")
+ # Combine.
+ total = outputs.logits.new_zeros(())
+ if lm_loss is not None:
+ total = total + self.lm_loss_weight * lm_loss
+ if kd_info is not None:
+ total = total + self.kd_loss_weight * kd_info["loss"]
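+        # Illustrative numbers: lm_loss_weight=1.0, kd_loss_weight=0.5,
+        # lm_loss=2.0, kd=0.8  ->  total = 1.0*2.0 + 0.5*0.8 = 2.4.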
- @staticmethod
- def _is_reverse_topk_loss(loss_type):
- return loss_type.startswith("r_kl_top") or loss_type.startswith("rkl_top")
+ # Logging: record every component whose weight is > 0, plus the
+ # always-informative forward/backward KL diagnostics when kd is on.
+ if lm_on and lm_loss is not None:
+ self._record("lm_loss", lm_loss)
+ if kd_on and kd_info is not None:
+ self._record(f"kd/{self.loss_type}", kd_info["loss"])
+            # Diagnostic forward/backward KL for any kd variant; useful to
+            # monitor teacher-student disagreement independently of the
+            # combined objective.
+ self._record("kd/forward_kl", kd_info["forward_kl"])
+ self._record("kd/backward_kl", kd_info["backward_kl"])
+ self._record("total_loss", total)
- def _resolve_topk(self, loss_type, vocab_size):
- topk = self.loss_topk
- if topk is None and "_top_" in loss_type:
- topk = int(loss_type.rsplit("_", 1)[-1])
- if topk is None:
- topk = 1000
- topk = int(topk)
- if topk <= 0:
- raise ValueError(f"loss_topk must be positive, got: {topk}")
- return min(topk, vocab_size)
+ return (total, outputs) if return_outputs else total
@torch.no_grad()
def get_ori_outputs(self, model, inputs):
@@ -136,6 +162,31 @@ def get_ori_outputs(self, model, inputs):
)
return outputs
+ def log(self, logs, start_time=None, *args, **kwargs):
+ """Inject running QAT loss components (lm_loss / kd/... / kd/forward_kl
+ / kd/backward_kl / total_loss) into HuggingFace Trainer's log dict.
+
+ Each value is averaged across the steps accumulated since the
+ previous ``log`` call, then the accumulator is cleared — matching
+ HF Trainer's behaviour for its built-in ``loss`` key.
+ """
+ mode = "train" if self.model.training else "eval"
+ bucket = self._qat_metrics.get(mode, {})
+ if bucket:
+ for key, vals in bucket.items():
+ if not vals:
+ continue
+ avg = sum(vals) / len(vals)
+ out_key = key if mode == "train" else f"eval_{key}"
+ logs[out_key] = float(avg)
+ bucket.clear()
+
+ # Forward to HF Trainer's log. Signature differs across versions.
+ try:
+ return super().log(logs, start_time, *args, **kwargs)
+ except TypeError:
+ return super().log(logs)
+
@TrainerFactory.register("end2end")
class End2EndTrainer:
@@ -157,6 +208,8 @@ def __init__(self, quant_model, config, plugin_manager):
"loss_topk": config["compress_config"].QAT.loss_topk,
"kd_temperature": config["compress_config"].QAT.kd_temperature,
"kd_alpha": config["compress_config"].QAT.kd_alpha,
+ "lm_loss_weight": config["compress_config"].QAT.lm_loss_weight,
+ "kd_loss_weight": config["compress_config"].QAT.kd_loss_weight,
}
self.quant_config = {
"use_weight_quant": config["compress_config"]
@@ -173,13 +226,13 @@ def __init__(self, quant_model, config, plugin_manager):
def _init_optimizer(self):
lr = float(self.config["compress_config"].QAT.hf_args.get("learning_rate", 1e-5))
wd = float(self.config["compress_config"].QAT.hf_args.get("weight_decay", 0))
+ scale_params = _unique_named_params(
+ self.quant_model.model,
+ lambda n, p: p.requires_grad and ("scale" in n or "zero_point" in n),
+ )
params = [
{
- "params": [
- p
- for n, p in self.quant_model.model.named_parameters()
- if "scale" in n or "zero_point" in n
- ],
+ "params": scale_params,
"weight_decay": wd,
"lr": lr,
}
@@ -191,6 +244,7 @@ def _init_optimizer(self):
.get("lwc", {})
.get("enable_lwc", False)
)
+ lwc_param_count = 0
if enable_lwc:
lwc_lr = float(
self.config["compress_config"]
@@ -198,30 +252,40 @@ def _init_optimizer(self):
.get("lwc", {})
.get("lwc_lr", 1e-1)
)
- lwc_params = [
- {
- "params": [
- p
- for n, p in self.quant_model.model.named_parameters()
- if "clip_factor_w_max" in n or "clip_factor_w_min" in n
- ],
- "weight_decay": wd,
- "lr": lwc_lr,
- }
- ]
- params.extend(lwc_params)
+ lwc_params = _unique_named_params(
+ self.quant_model.model,
+ lambda n, p: p.requires_grad
+ and ("clip_factor_w_max" in n or "clip_factor_w_min" in n),
+ )
+ lwc_param_count = len(lwc_params)
+ params.append({"params": lwc_params, "weight_decay": wd, "lr": lwc_lr})
+
+ if not any(group["params"] for group in params):
+ raise ValueError("QAT optimizer has no trainable parameters.")
self.optimizer = torch.optim.AdamW(params)
if enable_lwc:
- print_info(f"Init optimizer with learnable lr={lr} lwc_lr={lwc_lr} weight_decay={wd}")
+ print_info(
+ f"Init optimizer with {len(scale_params)} scale params, "
+ f"{lwc_param_count} lwc params, lr={lr} lwc_lr={lwc_lr} weight_decay={wd}"
+ )
else:
- print_info(f"Init optimizer with learnable lr={lr} weight_decay={wd}")
+ print_info(
+ f"Init optimizer with {len(scale_params)} scale params, "
+ f"lr={lr} weight_decay={wd}"
+ )
def prepare_trainer(self):
if self.training_mode == "blockwise":
return
if self.training_mode == "end2end" and self.dist_mode == "hf":
self._init_optimizer()
+ # When DeepSpeed is used, neutralize its duplicate-parameter
+ # check: it rejects param-groups that share tensors, which our
+ # scale/zero_point setup can legally have (shared tensors
+ # across views). Idempotent and a no-op if deepspeed is absent.
+ if self.config["compress_config"].QAT.hf_args.get("deepspeed") is not None:
+ patch_deepspeed_duplicate_check()
self.external_trainer = QATSeq2SeqTrainer(
model=self.quant_model.model,
processing_class=self.quant_model.tokenizer,
diff --git a/angelslim/data/dataloader.py b/angelslim/data/dataloader.py
index a3dcc0a6..72488ce1 100644
--- a/angelslim/data/dataloader.py
+++ b/angelslim/data/dataloader.py
@@ -44,6 +44,7 @@ def create_data_loader(
use_audio_in_video: bool = False,
model_name: str = None,
quantization_config: str = None,
+ is_sft_data: bool = False,
) -> DataLoader:
"""
Create appropriate DataLoader based on data source
@@ -85,6 +86,7 @@ def create_data_loader(
max_length=max_length,
data_path=data_source,
num_samples=num_samples,
+ is_sft_data=is_sft_data,
)
elif data_type == "MultiModalDataset":
dataset = MultiModalDataset(
diff --git a/angelslim/data/text_dataset.py b/angelslim/data/text_dataset.py
index 510b3a05..96e5a874 100644
--- a/angelslim/data/text_dataset.py
+++ b/angelslim/data/text_dataset.py
@@ -34,8 +34,10 @@ def __init__(
device: str = "cpu",
max_length: int = 4096,
num_samples: int = -1,
+ is_sft_data: bool = False,
):
super().__init__(processor, device, max_length)
+ self.is_sft_data = is_sft_data
self._load_data(data_path, num_samples)
def _load_data(self, data_path: str, num_samples: int):
@@ -74,9 +76,8 @@ def _load_hf_dataset(self, data_path: str, num_samples: int, block_size: int = 2
inputs = {
"input_ids": torch.tensor(result["input_ids"][i]).unsqueeze(0).to(self.device)
}
- labels = inputs["input_ids"].roll(shifts=-1, dims=-1)
- labels[:, -1] = -100
- inputs["labels"] = labels.to(self.device)
+ # HF CausalLM models shift labels internally; feed labels == input_ids.
+ inputs["labels"] = inputs["input_ids"].clone()
inputs["attention_mask"] = torch.tensor(result["attention_mask"][i]).to(self.device)
self.data.append(inputs)
@@ -101,8 +102,8 @@ def _load_parquet_data(self, data_path: str, num_samples: int):
if "labels" in df.columns:
labels = torch.tensor(df["labels"].iloc[i]).unsqueeze(0)
else:
- labels = model_inputs["input_ids"].roll(shifts=-1, dims=-1)
- labels[:, -1] = -100
+ # HF CausalLM models shift labels internally; feed labels == input_ids.
+ labels = model_inputs["input_ids"].clone()
data_item = {
"input_ids": model_inputs["input_ids"].to(self.device),
@@ -128,42 +129,83 @@ def _load_jsonl_data(self, data_path: str, num_samples: int):
# Prepare messages
messages = self._prepare_messages(data)
- # Apply chat template
- text = self.processor.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
+ # Find the LAST assistant turn — loss is computed ONLY on
+ # this reply. Everything before it (system + user(s) +
+ # earlier assistant(s)) serves as prompt context.
+ last_assistant_idx = None
+ for idx, item in enumerate(messages):
+ if item["role"] == "assistant":
+ last_assistant_idx = idx
+ if last_assistant_idx is None:
+ # No assistant turn -> nothing to supervise; skip.
+ continue
+ prompt_messages = messages[:last_assistant_idx]
+ assistant_msg = messages[last_assistant_idx]
+
+ # Tokenize the prompt (up to the generation marker) and the
+ # full conversation separately so we know exactly where the
+ # assistant reply starts.
+ prompt_text = self.processor.apply_chat_template(
+ prompt_messages, tokenize=False, add_generation_prompt=True
+ )
+ full_messages = prompt_messages + [assistant_msg]
+ full_text = self.processor.apply_chat_template(
+ full_messages, tokenize=False, add_generation_prompt=False
)
- thinking_data = False
- for dic in messages:
- if dic["role"] == "assistant":
- if "" and "" in dic["content"]:
- thinking_data = True
- break
+ # Legacy branch: thinking-style data without a chat template.
+ thinking_data = any(
+ m["role"] == "assistant"
+ and "" in m.get("content", "")
+ and "" in m.get("content", "")
+ for m in messages
+ )
if thinking_data:
- text = self.processor.bos_token if self.processor.bos_token is not None else ""
- for dic in messages:
- if dic["role"] == "system":
- text += dic["content"]
- elif dic["role"] == "user":
- text = text + "<|User|>" + dic["content"] + "<|Assistant|>"
- elif dic["role"] == "assistant":
- text = text + dic["content"] + self.processor.eos_token
+ bos = self.processor.bos_token or ""
+ prompt_text = bos
+ for m in prompt_messages:
+ if m["role"] == "system":
+ prompt_text += m["content"]
+ elif m["role"] == "user":
+ prompt_text += "<|User|>" + m["content"] + "<|Assistant|>"
+ elif m["role"] == "assistant":
+ prompt_text += m["content"] + self.processor.eos_token
+ full_text = prompt_text + assistant_msg["content"] + self.processor.eos_token
+
+ # Token-level prompt length: count tokens in ``prompt_text``
+ # without special-token insertion so it aligns with the
+ # prefix of the tokenization of ``full_text``.
+ prompt_ids = self.processor(
+ prompt_text,
+ add_special_tokens=False,
+ return_tensors=None,
+ )["input_ids"]
+ prompt_len = len(prompt_ids)
model_inputs = self.processor(
- text=[text],
+ text=[full_text],
return_tensors="pt",
max_length=self.max_length,
truncation=True,
padding="max_length",
)
- labels = model_inputs["input_ids"].roll(shifts=-1, dims=-1)
- labels[:, -1] = -100
+ # Build labels: HF CausalLM shifts labels internally, so
+ # the label at position ``t`` supervises the prediction of
+ # ``input_ids[t+1]``. Positions before (and at) the end of
+ # the prompt are set to -100 so they contribute no loss.
+ input_ids = model_inputs["input_ids"]
+ attention_mask = model_inputs["attention_mask"]
+ labels = input_ids.clone()
+ if self.is_sft_data:
+ labels[:, :prompt_len] = -100
+ # Also mask padding tokens.
+ labels[attention_mask == 0] = -100
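+                # Example (is_sft_data=True, illustrative): with prompt_len == 5,
+                # labels[:, :5] become -100, so the first supervised token is
+                # input_ids[:, 5] (the first assistant token), predicted from
+                # the logits at position 4 after HF's internal shift.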
self.data.append(
{
- "input_ids": model_inputs["input_ids"].to(self.device),
- "attention_mask": model_inputs["attention_mask"].to(self.device),
+ "input_ids": input_ids.to(self.device),
+ "attention_mask": attention_mask.to(self.device),
"labels": labels.to(self.device),
}
)
diff --git a/angelslim/engine.py b/angelslim/engine.py
index 757d8741..8c5344c2 100644
--- a/angelslim/engine.py
+++ b/angelslim/engine.py
@@ -101,6 +101,13 @@ def prepare_model(
assert model_name, "model_name must be specified."
assert model_path, "model_path must be specified."
+ # Normalize device_map for DeepSpeed ZeRO / distributed training: YAML
+ # configs often write ``None`` / ``"None"`` / ``"distributed"`` to
+ # mean "no pre-placement, let DeepSpeed shard". HF only accepts
+ # Python ``None`` there.
+ if isinstance(device_map, str) and device_map.lower() in ("none", "distributed"):
+ device_map = None
+
# Initialize slim model by ModelFactory
self.slim_model = SlimModelFactory.create(
model_name, model=model, deploy_backend=deploy_backend
@@ -152,6 +159,7 @@ def prepare_data(
use_audio_in_video=False,
model_name=None,
quantization_config=None,
+ is_sft_data=False,
) -> Optional[Any]:
"""Prepare compression dataset"""
if custom_dataloader is not None:
@@ -178,6 +186,7 @@ def prepare_data(
use_audio_in_video=use_audio_in_video,
model_name=model_name,
quantization_config=quantization_config,
+ is_sft_data=is_sft_data,
)
self.max_seq_length = max_length
diff --git a/angelslim/models/base_model.py b/angelslim/models/base_model.py
index 3f15befb..39ee6ece 100644
--- a/angelslim/models/base_model.py
+++ b/angelslim/models/base_model.py
@@ -24,7 +24,13 @@
from ..compressor.quant.core import QuantConfig
from ..compressor.quant.modules import NVFP4QDQModule, QDQModule
-from ..utils import common_prefix, print_info
+from ..utils import (
+ common_prefix,
+ is_deepspeed_zero3_enabled,
+ print_info,
+ stream_load_weights,
+ zero3_empty_model_from_pretrained,
+)
__all__ = ["BaseLLMModel"]
@@ -65,6 +71,27 @@ def from_pretrained(
using_multi_nodes=False,
attn_implementation="default",
):
+ # DeepSpeed ZeRO-3 path: build an empty sharded model on every rank
+ # via deepspeed.zero.Init, linearize fused MoE experts, then stream
+ # the safetensors checkpoint into the (sharded) parameters. This
+ # avoids HF's path that materialises the full state_dict on every
+ # rank's CPU before sharding.
+ if is_deepspeed_zero3_enabled():
+ log_prefix = f"[{type(self).__name__}.from_pretrained]"
+ self.model = zero3_empty_model_from_pretrained(
+ model_path,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ use_cache=use_cache,
+ attn_implementation=attn_implementation,
+ log_prefix=log_prefix,
+ )
+ stream_load_weights(self.model, model_path, log_prefix=log_prefix)
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=trust_remote_code
+ )
+ return
+
kwargs = dict(
torch_dtype=torch_dtype,
device_map=device_map,
diff --git a/angelslim/models/llm/hunyuan_v3_moe.py b/angelslim/models/llm/hunyuan_v3_moe.py
index d8919643..fdfd9c5c 100644
--- a/angelslim/models/llm/hunyuan_v3_moe.py
+++ b/angelslim/models/llm/hunyuan_v3_moe.py
@@ -16,20 +16,50 @@
import torch
import torch.nn as nn
-from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.hy_v3.modeling_hy_v3 import (
ALL_ATTENTION_FUNCTIONS,
HYV3Experts,
+ HYV3TopKRouter,
apply_rotary_pos_emb,
eager_attention_forward,
)
from ...compressor.quant.core import PTQSaveVllmHF
+from ...utils import is_deepspeed_zero3_enabled
from ...utils.utils import find_layers, find_parent_layer_and_sub_name
from ..base_model import BaseLLMModel
from ..model_factory import SlimModelFactory
+def _patch_hyv3_router_for_zero3():
+ if getattr(HYV3TopKRouter, "_angelslim_zero3_dtype_patch", False):
+ return
+
+ def patched_forward(
+ self,
+ hidden_states: torch.Tensor,
+ e_score_correction_bias: torch.Tensor,
+ ) -> tuple:
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = nn.functional.linear(
+ hidden_states.to(self.weight.dtype),
+ self.weight,
+ ).to(torch.float32)
+ routing_weights = torch.sigmoid(router_logits)
+
+ scores_for_choice = routing_weights + e_score_correction_bias
+ _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
+ top_k_weights = routing_weights.gather(1, top_k_index)
+
+ top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-20)
+ top_k_weights = top_k_weights * self.router_scaling_factor
+
+ return router_logits, top_k_weights, top_k_index
+
+ HYV3TopKRouter.forward = patched_forward
+ HYV3TopKRouter._angelslim_zero3_dtype_patch = True
+
+
class HYV3ExpertsWithLinear(HYV3Experts):
"""Wrapper around HYV3Experts that exposes per-expert weights as nn.Linear modules.
@@ -138,19 +168,18 @@ def from_pretrained(
):
attn_implementation = "eager"
torch_dtype = torch.bfloat16
- self.model = AutoModelForCausalLM.from_pretrained(
- model_path,
- attn_implementation=attn_implementation,
+ if is_deepspeed_zero3_enabled():
+ _patch_hyv3_router_for_zero3()
+
+ super().from_pretrained(
+ model_path=model_path,
torch_dtype=torch_dtype,
device_map=device_map,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=low_cpu_mem_usage,
use_cache=use_cache,
- )
-
- # Load tokenizer
- self.tokenizer = AutoTokenizer.from_pretrained(
- model_path, trust_remote_code=trust_remote_code
+ using_multi_nodes=using_multi_nodes,
+ attn_implementation=attn_implementation,
)
def replace_moe(self):
@@ -179,9 +208,9 @@ def get_observer_layers(self):
"mlp.gate_proj",
"mlp.up_proj",
"mlp.down_proj",
- "shared_mlp.gate_proj",
- "shared_mlp.up_proj",
- "shared_mlp.down_proj",
+ "shared_experts.gate_proj",
+ "shared_experts.up_proj",
+ "shared_experts.down_proj",
]
expert_pattern = [
r"model\.layers\.\d+\.mlp\.experts\.\d+\.gate_proj",
@@ -326,11 +355,9 @@ def patched_forward(
key_states, value_states, attn_module.layer_idx, cache_kwargs
)
- attention_interface = eager_attention_forward
- if attn_module.config._attn_implementation != "eager":
- attention_interface = ALL_ATTENTION_FUNCTIONS[
- attn_module.config._attn_implementation
- ]
+ attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
+ attn_module.config._attn_implementation, eager_attention_forward
+ )
attn_output, attn_weights = attention_interface(
attn_module,
diff --git a/angelslim/utils/__init__.py b/angelslim/utils/__init__.py
index d80be244..a12f24b5 100644
--- a/angelslim/utils/__init__.py
+++ b/angelslim/utils/__init__.py
@@ -30,3 +30,4 @@
from .utils import print_with_rank # noqa: F401
from .utils import rank0_print # noqa: F401
from .utils import set_op_by_name # noqa: F401
+from .zero3_io import * # noqa: F401 F403
diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py
index 77481b9e..31b5acdc 100644
--- a/angelslim/utils/config_parser.py
+++ b/angelslim/utils/config_parser.py
@@ -185,6 +185,7 @@ class DatasetConfig:
batch_size: int = field(default=1)
shuffle: bool = field(default=False)
inference_settings: Optional[Dict[str, Any]] = field(default=None)
+ is_sft_data: bool = field(default=False)
@dataclass
@@ -272,10 +273,23 @@ class QATTrainingConfig:
hf_dataset: Optional[str] = None
do_train: bool = field(default=True)
resume_ckpt_dir: Optional[str] = None
+ # Optional warm-start directory (a previous ``save_fmt="real"`` output).
+ # When set, scales / zero_points / kv-cache scales are loaded from this
+ # directory into the freshly-created ``QuantLinear`` quantizers; base
+ # ``Linear`` weights still come from ``model.model_path``. Required when
+ # DeepSpeed ZeRO-3 is enabled (no calibration is possible on shards).
+ from_ptq_ckpt: Optional[str] = None
loss_type: str = field(default="origin")
loss_topk: Optional[int] = None
kd_temperature: float = field(default=1.0)
kd_alpha: float = field(default=0.5)
+ # ---- new loss-weight controls (compose LM + KD loss) ----
+ # lm_loss_weight: weight on the HF CausalLM CE loss (labels must be set).
+    # kd_loss_weight: weight on the chosen distillation loss (kl / rkl / mse /
+    #   kd / kl_top[_K] / r_kl_top[_K] / cakld). Only loss components with a
+    #   strictly-positive weight are computed and logged.
+ lm_loss_weight: float = field(default=1.0)
+ kd_loss_weight: float = field(default=0.0)
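+    # Illustrative YAML fragment (field names match this dataclass; the exact
+    # top-level nesting follows the surrounding compression config):
+    #   QAT:
+    #     loss_type: kl
+    #     lm_loss_weight: 1.0
+    #     kd_loss_weight: 0.5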
hf_args: Dict[str, Any] = field(default_factory=dict)
diff --git a/angelslim/utils/utils.py b/angelslim/utils/utils.py
index 36422758..e10f2e98 100644
--- a/angelslim/utils/utils.py
+++ b/angelslim/utils/utils.py
@@ -49,10 +49,18 @@ def set_op_by_name(layer, name, new_module):
if len(levels) > 1:
mod_ = layer
for l_idx in range(len(levels) - 1):
- if levels[l_idx].isdigit():
- mod_ = mod_[int(levels[l_idx])]
+ part = levels[l_idx]
+ if part.isdigit():
+ # Prefer integer indexing for nn.ModuleList / nn.Sequential;
+ # fall back to getattr for custom containers (e.g. our
+ # LinearizedMoeExperts that registers experts via
+ # ``setattr(self, str(idx), ...)``).
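+                # e.g. for ``model.layers.0.mlp.experts.3.gate_proj`` each
+                # digit part is tried as an index first and as an attribute
+                # name when indexing is unsupported.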
+ try:
+ mod_ = mod_[int(part)]
+ except (TypeError, IndexError, KeyError):
+ mod_ = getattr(mod_, part)
else:
- mod_ = getattr(mod_, levels[l_idx])
+ mod_ = getattr(mod_, part)
setattr(mod_, levels[-1], new_module)
else:
setattr(layer, name, new_module)
diff --git a/angelslim/utils/zero3_io.py b/angelslim/utils/zero3_io.py
new file mode 100644
index 00000000..614ee5bd
--- /dev/null
+++ b/angelslim/utils/zero3_io.py
@@ -0,0 +1,853 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""All-in-one DeepSpeed ZeRO-3 helpers for AngelSlim QAT.
+
+This module concentrates everything ZeRO-3-specific so that the rest of the
+codebase touches only a few thin call-sites:
+
+ * ZeRO-3 detection / parameter gathering helpers.
+ * Empty model construction under ``deepspeed.zero.Init`` plus a generic
+ "linearize fused MoE experts" pass that builds *empty* per-expert
+ ``nn.Linear`` modules (no copy from the old fused tensor required —
+ weights are filled later from the safetensors checkpoint).
+ * Streaming weight loader: walks the safetensors shards once, on rank 0
+ only, and broadcasts each tensor into the (possibly sharded) target
+ parameter via ``GatheredParameters(modifier_rank=0)``. Handles fused
+ MoE keys (``...experts.gate_up_proj``, ``...experts.down_proj``) by
+ slicing per expert into the linearized targets.
+ * Streaming scale loader for QAT warm-start: reads ``*.weight_scale``,
+ ``*.input_scale``, ``*.k_cache.scale``, ``*.v_cache.scale`` keys from a
+ PTQ "real" checkpoint and writes them into the freshly-created
+ ``QuantLinear`` quantizer parameters.
+ * Saving: gather a sharded model into a rank-0 CPU state_dict and call
+ the model-specific save_func by patching ``state_dict``.
+ * Optimizer-side patches needed because QAT scale parameters are tied
+ across multiple modules in a layer.
+
+By design, **nothing in this file mutates the model when ZeRO-3 is not
+enabled** (each helper is a no-op or behaves identically to the
+non-distributed path). Importing this module is therefore safe in any
+configuration.
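+
+Illustrative usage (``linear`` is a stand-in for any module that holds a
+possibly-sharded Parameter)::
+
+    from angelslim.utils.zero3_io import gathered_param_if_zero3
+
+    with gathered_param_if_zero3(linear.weight):
+        # Full tensor materialised here; a no-op when ZeRO-3 is not active.
+        full_weight = linear.weight.detach().cpu().clone()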
+"""
+
+from __future__ import annotations
+
+import gc
+import glob
+import json
+import os
+from contextlib import contextmanager, nullcontext
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from .lazy_imports import deepspeed
+from .utils import find_parent_layer_and_sub_name, print_info
+
+ZERO3_PARAM_ATTRS = ("ds_id", "ds_status", "ds_numel", "ds_tensor")
+
+
+# ---------------------------------------------------------------------------
+# Basic detection / context helpers
+# ---------------------------------------------------------------------------
+
+
+def is_deepspeed_zero3_enabled() -> bool:
+ """True iff HuggingFace's ``HfTrainerDeepSpeedConfig`` is registered with
+ ZeRO stage 3. Returns False if ``transformers``/``deepspeed`` is not
+ importable."""
+ try:
+ from transformers.integrations.deepspeed import (
+ is_deepspeed_zero3_enabled as _hf,
+ )
+
+ return bool(_hf())
+ except Exception: # noqa: BLE001
+ return False
+
+
+def is_zero3_param(x) -> bool:
+ """True iff ``x`` is a ZeRO-3 sharded parameter (``deepspeed.zero.Init``
+ has injected its bookkeeping attributes)."""
+ if not isinstance(x, torch.nn.Parameter):
+ return False
+ return any(hasattr(x, attr) for attr in ZERO3_PARAM_ATTRS)
+
+
+@contextmanager
+def gathered_param_if_zero3(x, modifier_rank: Optional[int] = None):
+ """All-gather a ZeRO-3 shard for the lifetime of the block.
+
+ Pure no-op (yields ``x`` unchanged) when ``x`` is not a ZeRO-3 shard, so
+ callers can always wrap their critical sections without branching.
+ """
+ if is_zero3_param(x):
+ ctx = deepspeed.zero.GatheredParameters([x], modifier_rank=modifier_rank)
+ else:
+ ctx = nullcontext()
+ with ctx:
+ yield x
+
+
+@contextmanager
+def gathered_params_if_zero3(params, modifier_rank: Optional[int] = None):
+ """Batched variant of :func:`gathered_param_if_zero3`."""
+ params = [p for p in params if p is not None]
+ z3 = [p for p in params if is_zero3_param(p)]
+ if z3:
+ ctx = deepspeed.zero.GatheredParameters(z3, modifier_rank=modifier_rank)
+ else:
+ ctx = nullcontext()
+ with ctx:
+ yield params
+
+
+def model_has_zero3_params(model) -> bool:
+ return any(is_zero3_param(p) for p in model.parameters())
+
+
+def _rank() -> int:
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ return torch.distributed.get_rank()
+ return 0
+
+
+def _cleanup() -> None:
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+
+# ---------------------------------------------------------------------------
+# Empty MoE linearization
+# ---------------------------------------------------------------------------
+
+
+class LinearizedMoeExperts(nn.Module):
+ """Empty per-expert ``nn.Linear`` container for fused MoE experts.
+
+ This module mirrors the ``forward`` of HuggingFace's fused experts but
+ holds parameters as ``num_experts`` triplets of ``(gate_proj, up_proj,
+ down_proj)`` ``nn.Linear`` modules. Construction is **purely structural**
+ — no weight copy from the old fused tensor — so it is safe to instantiate
+ under ``deepspeed.zero.Init`` and let the streaming loader fill in the
+ weights afterwards.
+ """
+
+ _angelslim_linearized_moe = True
+
+ def __init__(
+ self,
+ num_experts: int,
+ hidden_dim: int,
+ intermediate_dim: int,
+ act_fn,
+ dtype=torch.bfloat16,
+ device=None,
+ config=None,
+ ):
+ super().__init__()
+ self.num_experts = int(num_experts)
+ self.hidden_dim = int(hidden_dim)
+ self.intermediate_dim = int(intermediate_dim)
+ self.act_fn = act_fn
+ if config is not None:
+ self.config = config
+
+ if device is None or (isinstance(device, torch.device) and device.type == "meta"):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ for expert_idx in range(self.num_experts):
+ expert = nn.ModuleDict(
+ {
+ "gate_proj": nn.Linear(
+ self.hidden_dim,
+ self.intermediate_dim,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ ),
+ "up_proj": nn.Linear(
+ self.hidden_dim,
+ self.intermediate_dim,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ ),
+ "down_proj": nn.Linear(
+ self.intermediate_dim,
+ self.hidden_dim,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ ),
+ }
+ )
+ setattr(self, str(expert_idx), expert)
+
+ def __getitem__(self, idx):
+ return getattr(self, str(idx))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ top_k_index: torch.Tensor,
+ top_k_weights: torch.Tensor,
+ ) -> torch.Tensor:
+ final_hidden_states = torch.zeros_like(hidden_states)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ expert_hit_mask = torch.greater(expert_mask.sum(dim=(-1, -2)), 0)
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ # ZeRO-3 gathers parameters when a Linear is entered. All ranks must
+ # enter the same experts in the same order even when their local
+ # batches route to different experts, otherwise collectives deadlock.
+ expert_hit_int = expert_hit_mask.to(torch.int32)
+ torch.distributed.all_reduce(
+ expert_hit_int,
+ op=torch.distributed.ReduceOp.MAX,
+ )
+ expert_hit_mask = expert_hit_int.to(torch.bool)
+ expert_hit = expert_hit_mask.nonzero()
+
+ for expert_idx in expert_hit:
+ expert_idx = expert_idx[0]
+ if expert_idx == self.num_experts:
+ continue
+ top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+ current_state = hidden_states[token_idx]
+ expert_layer = getattr(self, str(int(expert_idx.item())))
+ gate = expert_layer["gate_proj"](current_state)
+ up = expert_layer["up_proj"](current_state)
+ current_hidden_states = self.act_fn(gate) * up
+ current_hidden_states = expert_layer["down_proj"](current_hidden_states)
+ current_hidden_states = (
+ current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+ )
+ final_hidden_states.index_add_(
+ 0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
+ )
+
+ return final_hidden_states
+
+
+def _is_fused_moe_experts(module) -> bool:
+ """Heuristic: matches HF Qwen3MoeExperts / HYV3Experts / similar.
+
+ They all expose ``gate_up_proj`` / ``down_proj`` as ``nn.Parameter``
+ plus ``num_experts``, ``hidden_dim``, ``intermediate_dim``, ``act_fn``."""
+ if isinstance(module, LinearizedMoeExperts):
+ return False
+ required = (
+ "gate_up_proj",
+ "down_proj",
+ "num_experts",
+ "hidden_dim",
+ "intermediate_dim",
+ "act_fn",
+ )
+ return all(hasattr(module, a) for a in required)
+
+
+def _ds_full_shape(param):
+ """Full shape of a parameter, accounting for ZeRO-3 sharding."""
+ shape = getattr(param, "ds_shape", None)
+ if shape is None:
+ shape = param.shape
+ return tuple(int(x) for x in shape)
+
+
+def linearize_moe_experts_empty(model, dtype=torch.bfloat16) -> int:
+ """Replace every fused MoE experts module in ``model`` with an empty
+ :class:`LinearizedMoeExperts`.
+
+ Under ZeRO-3 the new ``nn.Linear`` parameters are created inside a
+ ``deepspeed.zero.Init`` context so they get partitioned immediately.
+ Weights are NOT copied from the old fused tensors — the streaming
+ safetensors loader is responsible for that. The old module is dropped
+ afterwards.
+ """
+ targets = []
+ for name, module in model.named_modules():
+ if _is_fused_moe_experts(module):
+ targets.append(name)
+
+ if not targets:
+ return 0
+
+ z3 = is_deepspeed_zero3_enabled()
+ replaced = 0
+ for name in targets:
+ parent, sub_name = find_parent_layer_and_sub_name(model, name)
+ old = getattr(parent, sub_name)
+
+ # Resolve dimensions from the (possibly sharded) old tensors.
+ gate_up_shape = _ds_full_shape(old.gate_up_proj)
+ down_shape = _ds_full_shape(old.down_proj)
+ num_experts = int(old.num_experts)
+ hidden_dim = int(old.hidden_dim)
+ # gate_up_proj: [num_experts, 2*intermediate_dim, hidden_dim]
+ if len(gate_up_shape) >= 3 and gate_up_shape[-2] % 2 == 0:
+ intermediate_dim = gate_up_shape[-2] // 2
+ elif len(down_shape) >= 3:
+ intermediate_dim = int(down_shape[-1])
+ else:
+ intermediate_dim = int(old.intermediate_dim)
+
+ old_dtype = old.gate_up_proj.dtype if hasattr(old.gate_up_proj, "dtype") else dtype
+ config = getattr(old, "config", None)
+ act_fn = old.act_fn
+
+ ctx = deepspeed.zero.Init() if z3 else nullcontext()
+ with ctx:
+ new_module = LinearizedMoeExperts(
+ num_experts=num_experts,
+ hidden_dim=hidden_dim,
+ intermediate_dim=intermediate_dim,
+ act_fn=act_fn,
+ dtype=old_dtype,
+ device=None, # let LinearizedMoeExperts pick cuda
+ config=config,
+ )
+
+ setattr(parent, sub_name, new_module)
+ del old
+ replaced += 1
+ _cleanup()
+
+ print_info(f"[zero3] linearize_moe_experts_empty: replaced {replaced} fused module(s).")
+ return replaced
+
+
+# ---------------------------------------------------------------------------
+# Empty model construction + streaming weight loader
+# ---------------------------------------------------------------------------
+
+
+def _resolve_dtype(torch_dtype, config):
+ if isinstance(torch_dtype, torch.dtype):
+ return torch_dtype
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
+ return getattr(torch, torch_dtype)
+ resolved = getattr(config, "torch_dtype", None) or torch.float32
+ if isinstance(resolved, str):
+ return getattr(torch, resolved)
+ return resolved
+
+
+def zero3_empty_model_from_pretrained(
+ model_path,
+ torch_dtype="auto",
+ trust_remote_code=True,
+ use_cache=False,
+ attn_implementation="default",
+ log_prefix="[zero3]",
+):
+ """Build an EMPTY ZeRO-3 sharded model from a HuggingFace ``model_path``.
+
+ Linearizes all fused MoE experts immediately so subsequent QuantLinear
+ insertion can iterate flat ``nn.Linear`` modules. Does NOT load
+ weights — caller must invoke :func:`stream_load_weights`.
+ """
+ from transformers import AutoConfig, AutoModelForCausalLM
+ from transformers.integrations.deepspeed import HfDeepSpeedConfig # noqa: F401
+
+ # ``no_init_weights`` / ``no_tie_weights`` moved across transformers
+ # versions: newest (>=5.x) expose them under ``transformers.initialization``;
+ # older releases kept them in ``modeling_utils`` or the top-level package.
+ no_init_weights = None
+ no_tie_weights = None
+ for mod in ("transformers.initialization", "transformers.modeling_utils", "transformers"):
+ try:
+ m = __import__(mod, fromlist=["no_init_weights"])
+ if hasattr(m, "no_init_weights"):
+ no_init_weights = m.no_init_weights
+ if hasattr(m, "no_tie_weights"):
+ no_tie_weights = m.no_tie_weights
+ except Exception: # noqa: BLE001
+ continue
+ if no_init_weights is not None and no_tie_weights is not None:
+ break
+ if no_init_weights is None:
+ no_init_weights = nullcontext # type: ignore
+ if no_tie_weights is None:
+ no_tie_weights = nullcontext # type: ignore
+
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+ if attn_implementation != "default":
+ config._attn_implementation = attn_implementation
+ if use_cache is not None:
+ config.use_cache = use_cache
+
+ resolved = _resolve_dtype(torch_dtype, config)
+ print_info(
+ f"{log_prefix} build empty ZeRO-3 model dtype={resolved} from config "
+ f"{getattr(config, 'model_type', '?')}"
+ )
+
+ # ``from_config`` triggers ``deepspeed.zero.Init`` automatically when
+ # ``HfTrainerDeepSpeedConfig`` is registered (i.e. when
+ # is_deepspeed_zero3_enabled() returns True).
+ with no_init_weights(), no_tie_weights():
+ model = AutoModelForCausalLM.from_config(
+ config,
+ torch_dtype=resolved,
+ trust_remote_code=trust_remote_code,
+ )
+
+ # Linearize fused MoE experts BEFORE weight loading so the loader can
+ # write per-expert slices directly into the new Linear targets.
+ linearize_moe_experts_empty(model, dtype=resolved)
+
+ return model
+
+
+def _shards(model_path):
+ """Yield ``(shard_path, [keys])`` for every safetensors shard."""
+ from safetensors import safe_open
+
+ index_path = os.path.join(model_path, "model.safetensors.index.json")
+ if os.path.isfile(index_path):
+ with open(index_path, "r") as f:
+ weight_map = json.load(f)["weight_map"]
+ per_shard = {}
+ for key, shard in weight_map.items():
+ per_shard.setdefault(shard, []).append(key)
+ for shard in sorted(per_shard):
+ yield os.path.join(model_path, shard), per_shard[shard]
+ return
+
+ paths = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
+ if not paths:
+ raise FileNotFoundError(f"No safetensors found under {model_path}")
+ for shard_path in paths:
+ with safe_open(shard_path, framework="pt") as r:
+ yield shard_path, list(r.keys())
+
+
+def _broadcast_into_target(src, target, *, is_buffer=False, key=None):
+ """Copy ``src`` (rank0 only, or None on other ranks) into ``target``.
+
+ Handles three cases:
+ * ZeRO-3 sharded ``Parameter``: gather, rank0 writes, exit gather.
+ * Regular distributed ``Parameter`` / replicated buffer: rank0 stages,
+ then broadcast.
+ * Single-process: direct copy.
+ """
+ dist_active = torch.distributed.is_available() and torch.distributed.is_initialized()
+
+ if is_zero3_param(target):
+ with gathered_param_if_zero3(target, modifier_rank=0):
+ if _rank() == 0:
+ if src is None or src.shape != target.shape:
+ return False
+ target.data.copy_(src.to(device=target.device, dtype=target.dtype))
+ return True
+
+ # Regular tensor (parameter or buffer).
+ if dist_active:
+ if _rank() == 0:
+ if src is None or (not is_buffer and src.shape != target.shape):
+ return False
+ tmp = src.to(device=target.device, dtype=target.dtype).contiguous()
+ else:
+ tmp = torch.empty_like(target)
+ torch.distributed.broadcast(tmp, src=0)
+ target.data.copy_(tmp)
+ return True
+
+ if src is None:
+ return False
+ target.data.copy_(src.to(device=target.device, dtype=target.dtype))
+ return True
+
+
+def stream_load_weights(model, model_path, log_prefix="[zero3]"):
+ """Stream a HF safetensors checkpoint into ``model``.
+
+ Recognises fused MoE keys ``*.experts.gate_up_proj`` and
+ ``*.experts.down_proj`` and dispatches the per-expert slices into the
+ matching :class:`LinearizedMoeExperts` children. All other keys are
+ matched against ``model.named_parameters()`` / ``named_buffers()``.
+
+ rank0 reads the bytes; ZeRO-3 sharded targets are filled inside
+ ``GatheredParameters(modifier_rank=0)``; replicated tensors are
+ broadcast.
+ """
+ from safetensors import safe_open
+
+ name_to_param = dict(model.named_parameters())
+ name_to_buffer = dict(model.named_buffers())
+ rank = _rank()
+
+ loaded = 0
+ skipped = 0
+ seen_targets = set()
+
+ for shard_path, keys in _shards(model_path):
+ with safe_open(shard_path, framework="pt") as reader:
+ for key in keys:
+ if key.endswith(".experts.gate_up_proj"):
+ base = key[: -len(".gate_up_proj")]
+ src = reader.get_tensor(key) if rank == 0 else None
+ n_exp = (
+ int(src.shape[0])
+ if src is not None
+ else _infer_num_experts(base, name_to_param)
+ )
+ for i in range(n_exp):
+ gkey = f"{base}.{i}.gate_proj.weight"
+ ukey = f"{base}.{i}.up_proj.weight"
+ gtgt = name_to_param.get(gkey)
+ utgt = name_to_param.get(ukey)
+ if gtgt is None or utgt is None:
+ skipped += 2
+ continue
+ gsrc = src[i].chunk(2, dim=-2)[0] if src is not None else None
+ usrc = src[i].chunk(2, dim=-2)[1] if src is not None else None
+ if _broadcast_into_target(gsrc, gtgt, key=gkey):
+ seen_targets.add(gkey)
+ loaded += 1
+ else:
+ skipped += 1
+ if _broadcast_into_target(usrc, utgt, key=ukey):
+ seen_targets.add(ukey)
+ loaded += 1
+ else:
+ skipped += 1
+ del src
+ elif key.endswith(".experts.down_proj"):
+ base = key[: -len(".down_proj")]
+ src = reader.get_tensor(key) if rank == 0 else None
+ n_exp = (
+ int(src.shape[0])
+ if src is not None
+ else _infer_num_experts(base, name_to_param)
+ )
+ for i in range(n_exp):
+ dkey = f"{base}.{i}.down_proj.weight"
+ dtgt = name_to_param.get(dkey)
+ if dtgt is None:
+ skipped += 1
+ continue
+ dsrc = src[i] if src is not None else None
+ if _broadcast_into_target(dsrc, dtgt, key=dkey):
+ seen_targets.add(dkey)
+ loaded += 1
+ else:
+ skipped += 1
+ del src
+ else:
+ tgt = name_to_param.get(key)
+ is_buf = False
+ if tgt is None:
+ tgt = name_to_buffer.get(key)
+ is_buf = tgt is not None
+ if tgt is None:
+ skipped += 1
+ continue
+ src = reader.get_tensor(key) if rank == 0 else None
+ if _broadcast_into_target(src, tgt, is_buffer=is_buf, key=key):
+ seen_targets.add(key)
+ loaded += 1
+ else:
+ skipped += 1
+ del src
+ _cleanup()
+ print_info(f"{log_prefix} loaded shard {os.path.basename(shard_path)}")
+
+ all_targets = set(name_to_param) | set(name_to_buffer)
+ missing = sorted(all_targets - seen_targets)
+ print_info(
+ f"{log_prefix} stream_load_weights done: "
+ f"loaded={loaded} skipped={skipped} missing={len(missing)}"
+ )
+ if missing:
+ print_info(f"{log_prefix} first missing keys: {missing[:10]}")
+
+ try:
+ model.tie_weights()
+ except Exception as e: # noqa: BLE001
+ print_info(f"{log_prefix} tie_weights skipped: {e}")
+
+
+def _infer_num_experts(base, name_to_param):
+ prefix = f"{base}."
+ ids = []
+ for name in name_to_param:
+ if not name.startswith(prefix):
+ continue
+ first = name[len(prefix) :].split(".", 1)[0]
+ if first.isdigit():
+ ids.append(int(first))
+ return (max(ids) + 1) if ids else 0
+
+
+# ---------------------------------------------------------------------------
+# Streaming PTQ-scale loader for QAT warm-start
+# ---------------------------------------------------------------------------
+
+
+_SCALE_SUFFIX_RULES = [
+ # (suffix_in_ckpt, quantizer_attr_on_QuantLinear, sub_attr, layer_name_rewrite)
+ (".weight_zero_point", "weight_quantizer", "zero_point", None),
+ (".input_zero_point", "act_quantizer", "zero_point", None),
+ (".weight_scale", "weight_quantizer", "scale", None),
+ (".input_scale", "act_quantizer", "scale", None),
+ (".k_cache.scale", "qkv_quantizer", "scale", ".k_proj"),
+ (".v_cache.scale", "qkv_quantizer", "scale", ".v_proj"),
+]
+# Match the longest suffix first so a shorter rule can never shadow a more
+# specific one such as '.k_cache.scale'.
+_SCALE_SUFFIX_RULES.sort(key=lambda r: len(r[0]), reverse=True)
+
+
+def _parse_scale_key(key):
+ for suffix, qname, sub, rewrite in _SCALE_SUFFIX_RULES:
+ if key.endswith(suffix):
+ base = key[: -len(suffix)]
+ return (base + rewrite if rewrite else base), qname, sub
+ return None
+
+
+def _expand_scale_targets(layer_name, qname, sub, named_modules):
+ """Expand a checkpoint key that targets a fused MoE expert tensor into the
+ matching per-expert linears in the linearized model. For non-MoE keys
+ returns ``[(layer_name, qname, sub)]`` unchanged."""
+ if layer_name in named_modules:
+ return [(layer_name, qname, sub)]
+
+ # PTQ checkpoint may store scales for the fused expert matrix; map them
+ # to every per-expert Linear we expanded into.
+ if layer_name.endswith(".experts.gate_up_proj"):
+ base = layer_name[: -len(".gate_up_proj")]
+ return [
+ (n, qname, sub)
+ for n in named_modules
+ if n.startswith(base + ".") and (n.endswith(".gate_proj") or n.endswith(".up_proj"))
+ ]
+ if layer_name.endswith(".experts.down_proj"):
+ base = layer_name[: -len(".down_proj")]
+ return [
+ (n, qname, sub)
+ for n in named_modules
+ if n.startswith(base + ".") and n.endswith(".down_proj")
+ ]
+ return []
+
+
+def _copy_scale_into(src, target):
+ """rank0-driven copy of a scale-like tensor into a (possibly ZeRO-3)
+ Parameter, with shape coercion for scalar/per-tensor mismatch."""
+ rank = _rank()
+ ok = True
+ with gathered_param_if_zero3(target, modifier_rank=0):
+ if rank == 0:
+ s = src
+ if s.numel() == target.numel():
+ s = s.reshape(target.shape)
+ else:
+ try:
+ s = s.expand_as(target).contiguous()
+ except RuntimeError:
+ ok = False
+ if ok:
+ target.data.copy_(s.to(device=target.device, dtype=target.dtype))
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ device = (
+ torch.device("cuda", torch.cuda.current_device())
+ if torch.cuda.is_available()
+ else target.device
+ )
+ flag = torch.tensor(int(ok), device=device)
+ torch.distributed.broadcast(flag, src=0)
+ ok = bool(flag.item())
+ return ok
+
+
+def stream_load_scales(model, ckpt_dir, log_prefix="[zero3]"):
+ """Read a PTQ "real" checkpoint and write its scale / zero_point /
+ kv-cache scale tensors into the matching ``QuantLinear`` quantizer
+ parameters of ``model``.
+
+ Sets ``act_quantizer.init = True`` for every static activation
+ quantizer that successfully receives a scale, so the lazy-init pass is
+ skipped.
+ """
+ from safetensors import safe_open
+
+ # Resolve nested layout (some PTQ exporters nest under final_quant_checkpoint/).
+ if not glob.glob(os.path.join(ckpt_dir, "*.safetensors")):
+ nested = os.path.join(ckpt_dir, "final_quant_checkpoint")
+ if os.path.isdir(nested):
+ ckpt_dir = nested
+
+ files = sorted(glob.glob(os.path.join(ckpt_dir, "*.safetensors")))
+ if not files:
+ raise FileNotFoundError(f"No *.safetensors in {ckpt_dir}")
+
+ # Lazy import to avoid circular dependency: this module is imported by
+ # angelslim.utils, which is imported very early.
+ from ..compressor.qat.modules.quantizer import QuantLinear
+
+ named_modules = dict(model.named_modules())
+ rank = _rank()
+ loaded = 0
+ skipped = 0
+
+ for src_file in files:
+ with safe_open(src_file, framework="pt") as reader:
+ for key in reader.keys():
+ parsed = _parse_scale_key(key)
+ if parsed is None:
+ continue
+ layer_name, qname, sub = parsed
+ targets = _expand_scale_targets(layer_name, qname, sub, named_modules)
+ if not targets:
+ skipped += 1
+ continue
+ src = reader.get_tensor(key) if rank == 0 else None
+ for tgt_layer, tgt_qname, tgt_sub in targets:
+ module = named_modules.get(tgt_layer)
+ if not isinstance(module, QuantLinear):
+ skipped += 1
+ continue
+ quantizer = getattr(module, tgt_qname, None)
+ if quantizer is None:
+ skipped += 1
+ continue
+ target = getattr(quantizer, tgt_sub, None)
+ if not isinstance(target, torch.nn.Parameter):
+ skipped += 1
+ continue
+ if _copy_scale_into(src, target):
+ loaded += 1
+ if tgt_qname == "act_quantizer" and tgt_sub == "scale":
+ quantizer.init = True
+ else:
+ skipped += 1
+
+ print_info(
+ f"{log_prefix} stream_load_scales: loaded={loaded} skipped={skipped} from {ckpt_dir}"
+ )
+
+
+# ---------------------------------------------------------------------------
+# Saving a sharded model via the model-specific save_func
+# ---------------------------------------------------------------------------
+
+
+def consolidated_state_dict(model):
+ """rank-0 CPU state_dict for a possibly ZeRO-3 sharded ``model``.
+
+ Other ranks see an empty dict (matching the contract of HF/Trainer
+ save callbacks). Includes persistent buffers."""
+ rank = _rank()
+ sd = {}
+ for name, param in model.named_parameters():
+ with gathered_param_if_zero3(param):
+ if rank == 0:
+ sd[name] = param.detach().cpu().clone()
+ if rank == 0:
+ for module_name, module in model.named_modules():
+ for buf_name, buf in module.named_buffers(recurse=False):
+ if buf is None or buf_name in module._non_persistent_buffers_set:
+ continue
+ full = f"{module_name}.{buf_name}" if module_name else buf_name
+ sd[full] = buf.detach().cpu().clone()
+ return sd
+
+
+def save_via_model_save_func(
+ quant_model,
+ save_func,
+ save_target_dir,
+ prebuilt_state_dict=None,
+):
+ """Invoke ``save_func.save(...)`` with the model's ``state_dict`` patched
+ to return a consolidated rank-0 dict.
+
+ Parameters
+ ----------
+ prebuilt_state_dict : dict | None
+ If provided (rank 0), use this dict directly and skip the per-param
+ gather+clone pass over ``model``. This is the **recommended** path
+ under ZeRO-3 because the caller (typically ``QAT.convert``) has
+ already produced rank0's full state_dict layer-by-layer with bounded
+ peak memory. Other ranks may pass ``None``.
+
+ If ``None``, fall back to ``consolidated_state_dict(model)``, which
+ gathers every parameter once more — avoid this when ``model``
+ already holds large materialised tensors on rank 0 (it would double
+ the peak).
+
+ No-op (delegates straight to ``save_func.save``) when no parameters are
+ sharded.
+ """
+ if not model_has_zero3_params(quant_model.model) and prebuilt_state_dict is None:
+ save_func.save(save_target_dir)
+ return
+
+ rank = _rank()
+ if prebuilt_state_dict is not None:
+ sd = prebuilt_state_dict if rank == 0 else {}
+ else:
+ sd = consolidated_state_dict(quant_model.model)
+
+ hf_model = quant_model.get_model()
+ original = hf_model.state_dict
+
+ def _patched(*args, **kwargs):
+ return sd if rank == 0 else {}
+
+ try:
+ hf_model.state_dict = _patched # type: ignore[method-assign]
+ if rank == 0:
+ save_func.save(save_target_dir)
+ finally:
+ hf_model.state_dict = original # type: ignore[method-assign]
+
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ torch.distributed.barrier()
+
+
+# ---------------------------------------------------------------------------
+# DeepSpeed engine: tolerate scale parameters that appear in multiple param
+# groups (the scale tensor itself is unique, but our optimizer construction
+# may pick it up from multiple QuantLinear children of a shared MoE expert).
+# ---------------------------------------------------------------------------
+
+
+def patch_deepspeed_duplicate_check():
+ """No-op DeepSpeed's ``_check_for_duplicates`` so QAT scale parameters
+ that share storage across modules don't crash optimizer init.
+
+ Idempotent: only patches once per process.
+ """
+ try:
+ from deepspeed.runtime.engine import DeepSpeedEngine
+ except Exception as exc: # noqa: BLE001
+ print_info(f"[zero3] skip duplicate-check patch: {exc}")
+ return
+ if getattr(DeepSpeedEngine, "_angelslim_skip_dup_check", False):
+ return
+
+ def _noop(self, basic_optimizer): # noqa: ARG001
+ return
+
+ DeepSpeedEngine._check_for_duplicates = _noop
+ DeepSpeedEngine._angelslim_skip_dup_check = True
+ print_info("[zero3] patched DeepSpeed _check_for_duplicates (idempotent).")
diff --git a/configs/hunyuan/ptq/fp8_static/hunyuanv3_a20b_fp8_static.yaml b/configs/hunyuan/ptq/fp8_static/hunyuanv3_a20b_fp8_static.yaml
new file mode 100644
index 00000000..06558ffd
--- /dev/null
+++ b/configs/hunyuan/ptq/fp8_static/hunyuanv3_a20b_fp8_static.yaml
@@ -0,0 +1,38 @@
+# Global configuration of pipeline
+global:
+ save_path: ./output_ptq_hy3
+
+# Simplified Configuration for LLM compression
+model:
+ name: HYV3MoE
+ model_path: tencent/Hy3-preview
+ trust_remote_code: true
+ low_cpu_mem_usage: true
+ use_cache: false
+ torch_dtype: auto
+ device_map: auto
+
+# Compression configuration
+compression:
+ name: PTQ
+ quantization:
+ name: fp8_static
+ bits: 8
+ quant_method:
+ weight: "per-tensor"
+ activation: "per-tensor"
+ kv_cache: "per-tensor"
+ ignore_layers:
+ - "lm_head"
+ - "model.embed_tokens"
+ - "gate.weight"
+ cpu_convert: true
+ save_name: "fp8"
+
+# Dataset for calibration
+dataset:
+ name: TextDataset
+ data_path: ./dataset/sharegpt_gpt4/sharegpt_gpt4_256.jsonl
+ max_seq_length: 2048
+ num_samples: 512
+ batch_size: 1
diff --git a/configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json b/configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json
new file mode 100644
index 00000000..e513dd7b
--- /dev/null
+++ b/configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json
@@ -0,0 +1,45 @@
+{
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "scheduler": {
+ "type": "WarmupDecayLR",
+ "params": {
+ "total_num_steps": "auto",
+ "warmup_min_lr": "auto",
+ "warmup_max_lr": "auto",
+ "warmup_num_steps": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "none"
+ },
+ "offload_param": {
+ "device": "none"
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 10,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
diff --git a/configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml
new file mode 100644
index 00000000..d110e90d
--- /dev/null
+++ b/configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml
@@ -0,0 +1,72 @@
+global:
+ save_path: ./output_hy3_zero3
+
+model:
+ name: HYV3MoE
+ model_path: tencent/Hy3-preview
+ trust_remote_code: true
+ torch_dtype: auto
+ device_map: None
+ low_cpu_mem_usage: true
+ use_cache: false
+
+compression:
+ name: QAT
+ quantization:
+ name: fp8_static
+ bits: 8
+ quant_method:
+ weight: per-tensor
+ activation: per-tensor
+ ignore_layers: ["lm_head", "embed_tokens", "gate.weight"]
+ QAT:
+ hf_dataset: null
+ from_ptq_ckpt: ./output_ptq_hy3/hunyuanv3_a20b_fp8_static
+ training_mode: end2end
+ dist_mode: hf
+ save_format: real
+ do_train: true
+ loss_type: cakld
+ loss_topk: null
+ kd_temperature: 1.0
+ kd_alpha: 0.5
+ lm_loss_weight: 1.0
+ kd_loss_weight: 1.0
+ plugin_config:
+ enable_scale: true
+ quant_config:
+ use_weight_quant: true
+ use_activation_quant: true
+ use_qkv_quant: false
+ weight_scale_init_value: 0.1
+ activation_scale_init_value: 0.1
+ learnable:
+ act_scale: false
+ weight_scale: true
+ kv_scale: false
+ norm: false
+ lwc: false
+ hf_args:
+ bf16: true
+ logging_steps: 1
+ logging_first_step: true
+ per_device_train_batch_size: 1
+ gradient_accumulation_steps: 1
+ learning_rate: 1.0e-6
+ lr_scheduler_type: cosine
+ num_train_epochs: 1
+ max_steps: 3
+ save_strategy: "no"
+ warmup_ratio: 0.0
+ max_grad_norm: 1.0
+ gradient_checkpointing: true
+ gradient_checkpointing_kwargs:
+ use_reentrant: true
+ deepspeed: configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json
+
+dataset:
+ name: TextDataset
+ data_path: ./dataset/sharegpt_gpt4/sharegpt_gpt4_256.jsonl
+ max_seq_length: 2048
+ num_samples: 256
+ batch_size: 1
diff --git a/configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json b/configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json
new file mode 100644
index 00000000..e513dd7b
--- /dev/null
+++ b/configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json
@@ -0,0 +1,45 @@
+{
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "scheduler": {
+ "type": "WarmupDecayLR",
+ "params": {
+ "total_num_steps": "auto",
+ "warmup_min_lr": "auto",
+ "warmup_max_lr": "auto",
+ "warmup_num_steps": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "none"
+ },
+ "offload_param": {
+ "device": "none"
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 10,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml
new file mode 100644
index 00000000..f2940943
--- /dev/null
+++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml
@@ -0,0 +1,86 @@
+global:
+ save_path: ./output
+
+model:
+ name: Qwen
+ # Local directory or HuggingFace repo id of the base Qwen3-30B-A3B model.
+ model_path: Qwen/Qwen3-30B-A3B
+ trust_remote_code: true
+ torch_dtype: auto
+ # ZeRO-3 wants HF to NOT place the model with ``device_map``. Pass
+ # ``None`` here; the Engine normalises the string "None"/"distributed".
+ device_map: None
+ low_cpu_mem_usage: true
+ use_cache: false
+
+compression:
+ name: QAT
+ quantization:
+ name: fp8_static
+ bits: 8
+ quant_method:
+ weight: per-tensor
+ activation: per-tensor
+ ignore_layers: ["lm_head", "embed_tokens", "gate.weight"]
+ QAT:
+ # Leave ``hf_dataset`` unset → the trainer falls back to the local
+ # TextDataset defined in the ``dataset:`` section below.
+ hf_dataset: null
+ # REQUIRED under ZeRO-3: warm-start scales from a previous PTQ "real"
+ # checkpoint. Produce it with e.g.
+ # python tools/run.py -c configs/qwen3/ptq/fp8_static/qwen3-a3b_fp8_static.yaml
+ from_ptq_ckpt: ./output_ptq_30b/qwen3-a3b_fp8_static
+ training_mode: end2end
+ dist_mode: hf
+ save_format: real
+ do_train: true
+ # KD + LM loss combination: loss_type picks the KD variant,
+ # lm_loss_weight / kd_loss_weight control the outer mix.
+ # loss_type choices: kl | rkl | mse | cakld | kd | kl_top_K | r_kl_top_K
+ loss_type: cakld
+ loss_topk: null
+ kd_temperature: 1.0
+ kd_alpha: 0.5
+ lm_loss_weight: 1.0
+ kd_loss_weight: 1.0
+ plugin_config:
+ enable_scale: true
+ quant_config:
+ use_weight_quant: true
+ use_activation_quant: true
+ use_qkv_quant: false
+ # Init values used when weight data is not accessible (e.g. under
+ # ZeRO-3) and for quantizer parameters missing from the PTQ
+ # checkpoint.
+ weight_scale_init_value: 0.1
+ activation_scale_init_value: 0.1
+ learnable:
+ act_scale: false
+ weight_scale: true
+ kv_scale: false
+ norm: false
+ lwc: false
+ hf_args:
+ bf16: true
+ logging_steps: 1
+ logging_first_step: true
+ per_device_train_batch_size: 1
+ gradient_accumulation_steps: 1
+ learning_rate: 1.0e-6
+ lr_scheduler_type: cosine
+ num_train_epochs: 1
+ max_steps: 3
+ save_strategy: "no"
+ warmup_ratio: 0.0
+ max_grad_norm: 1.0
+ gradient_checkpointing: true
+ gradient_checkpointing_kwargs:
+ use_reentrant: true
+ deepspeed: configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json
+
+dataset:
+ name: TextDataset
+ data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl
+ max_seq_length: 4096
+ num_samples: 256
+ batch_size: 1
diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml
new file mode 100644
index 00000000..bb0822b2
--- /dev/null
+++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml
@@ -0,0 +1,73 @@
+global:
+ save_path: ./output_4b_zero3
+
+model:
+ name: Qwen
+ model_path: Qwen/Qwen3-4B
+ trust_remote_code: true
+ torch_dtype: auto
+ device_map: None
+ low_cpu_mem_usage: true
+ use_cache: false
+
+compression:
+ name: QAT
+ quantization:
+ name: fp8_static
+ bits: 8
+ quant_method:
+ weight: per-tensor
+ activation: per-tensor
+ ignore_layers: ["lm_head", "embed_tokens"]
+ QAT:
+ hf_dataset: null
+ from_ptq_ckpt: ./output_ptq/qwen3-4b_fp8_static
+ training_mode: end2end
+ dist_mode: hf
+ save_format: real
+ do_train: true
+ loss_type: cakld
+ loss_topk: null
+ kd_temperature: 1.0
+ kd_alpha: 0.5
+ lm_loss_weight: 1.0
+ kd_loss_weight: 1.0
+ plugin_config:
+ enable_scale: true
+ quant_config:
+ use_weight_quant: true
+ use_activation_quant: true
+ use_qkv_quant: false
+ weight_scale_init_value: 0.1
+ activation_scale_init_value: 0.1
+ learnable:
+ act_scale: false
+ weight_scale: true
+ kv_scale: false
+ norm: false
+ lwc: false
+ hf_args:
+ bf16: true
+ logging_steps: 1
+ logging_first_step: true
+ per_device_train_batch_size: 1
+ gradient_accumulation_steps: 1
+ learning_rate: 1.0e-6
+ lr_scheduler_type: cosine
+ num_train_epochs: 1
+ max_steps: 5
+ save_strategy: "no"
+ warmup_ratio: 0.0
+ max_grad_norm: 1.0
+ gradient_checkpointing: true
+ gradient_checkpointing_kwargs:
+ use_reentrant: true
+ deepspeed: configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json
+
+dataset:
+ name: TextDataset
+ data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl
+ max_seq_length: 512
+ num_samples: 16
+ batch_size: 1
+ is_sft_data: true
diff --git a/docs/source/features/quantization/index.md b/docs/source/features/quantization/index.md
index 2048f22d..3878481c 100644
--- a/docs/source/features/quantization/index.md
+++ b/docs/source/features/quantization/index.md
@@ -13,5 +13,6 @@ awq
gptq
fp8_lepto
qat
+qat_zero3
daq
:::
diff --git a/docs/source/features/quantization/qat_zero3.md b/docs/source/features/quantization/qat_zero3.md
new file mode 100644
index 00000000..49c8e2ff
--- /dev/null
+++ b/docs/source/features/quantization/qat_zero3.md
@@ -0,0 +1,323 @@
+# QAT + DeepSpeed ZeRO-3
+
+## Overview
+
+This document describes the **DeepSpeed ZeRO-3** support in the AngelSlim QAT module. The motivation: once the base model no longer fits on a single GPU (e.g. MoE models such as Qwen3-30B-A3B), parameters, gradients and optimizer states have to be sharded across multiple GPUs. HuggingFace + DeepSpeed ZeRO-3 is the established way to do that, but plugging it directly into the QAT pipeline runs into the following problems:
+
+1. Under ZeRO-3, HuggingFace `from_pretrained` makes every rank load the **full** state_dict on CPU before partitioning, so peak memory is roughly `world_size × model_size`; large models inevitably OOM.
+2. QAT replaces the original `nn.Linear` modules with `QuantLinear` and splits fused MoE experts into per-expert `nn.Linear` modules. Performing these replacements on ZeRO-3 sharded parameters triggers sharding / gathering / redistribution all at once; the logic is highly error-prone and the memory peak is uncontrollable.
+3. QAT's activation scales are normally calibrated via forward passes (lazy init), which cannot run under ZeRO-3; weight-scale initialization reads the full weights, which is equally infeasible.
+4. When `convert()` turns `QuantLinear` into `QDQModule`, and when `save_format=real` exports the compressed weights, parameters must be gathered across GPUs while keeping a single CPU copy on **rank 0 only**.
+
+To address these problems, the key design choices of this implementation are:
+
+- **Every rank builds an empty model independently** (no weights read from disk) and partitions it immediately through `deepspeed.zero.Init`, so the peak memory is only `model_size / world_size`.
+- **Fused MoE linearization happens while the model is still empty** (duck-typing on `gate_up_proj / down_proj / num_experts / hidden_dim / intermediate_dim / act_fn`); the newly created per-expert `nn.Linear` modules are sharded immediately by `deepspeed.zero.Init`, and **no data is copied from the old fused tensors**.
+- **Weights and scales are both streamed in from safetensors**: rank 0 reads the shards in order while the other ranks receive the data via `GatheredParameters(modifier_rank=0)`. Fused MoE keys `.experts.gate_up_proj` / `.experts.down_proj` are sliced automatically into each per-expert target.
+- **Under ZeRO-3 the scales must be loaded from an external PTQ checkpoint** (`from_ptq_ckpt`), skipping forward calibration; scales without a matching key fall back to `weight_scale_init_value` / `activation_scale_init_value`.
+- **`convert / save` builds the `QDQModule`s and writes the consolidated state_dict on rank 0 only**; the other ranks merely take part in the NCCL gather collectives and hold no CPU data.
+
+Throughout, the non-ZeRO-3 path behaves exactly as on main: all ZeRO-3 logic is concentrated in the new file `angelslim/utils/zero3_io.py`, and the other files only add a few thin call-sites.
+
+## Architecture
+
+### Module layout
+
+```
+angelslim/
+├── utils/
+│ └── zero3_io.py # ALL ZeRO-3 helpers: detection,
+│ # gather/scatter, empty-model build,
+│ # streaming weight/scale loaders,
+│ # consolidated save, optimizer patch.
+├── compressor/qat/
+│ ├── qat.py # ZeRO-3 branch for convert() / save()
+│ ├── modules/quantizer.py # init scales from init_value when
+│ # weight data is not accessible; dtype
+│ # alignment for DeepSpeed autocast.
+│ ├── plugins/learnable_scale.py # stream_load_scales + skip lazy_init
+│ │ # + gathered quant_inplace
+│ └── trainers/end2end_trainer.py # lm_loss + kd_loss composition with
+│ # cakld support + per-component logging
+├── data/text_dataset.py # supervise ONLY the last assistant turn
+├── models/base_model.py # from_pretrained → ZeRO-3 path
+├── engine.py # normalise ``device_map`` string
+└── utils/
+ ├── config_parser.py # QATTrainingConfig new fields
+ ├── utils.py # set_op_by_name handles string-indexed
+ └── __init__.py # re-export zero3_io helpers
+```
+
+### Execution flow (ZeRO-3 path)
+
+```
+tools/run.py
+ └── _prewarm_hf_deepspeed_config() # register HfTrainerDeepSpeedConfig
+ └── Engine.prepare_model()
+ └── BaseLLMModel.from_pretrained() # ZeRO-3 branch
+ ├── zero3_empty_model_from_pretrained()
+ │ ├── AutoModelForCausalLM.from_config(...)
+ │ │ # triggers deepspeed.zero.Init for every Parameter
+ │ └── linearize_moe_experts_empty()
+ │ # fused Qwen3MoeExperts → empty LinearizedMoeExperts
+ └── stream_load_weights() # rank0 reads safetensors
+ └── QAT.__init__()
+ ├── init_ptq() → Qwen.replace_moe() # no-op: already linearised
+ └── register LearnableScalePlugin(from_ptq_ckpt_dir=...)
+ └── LearnableScalePlugin.before_train()
+ ├── replace nn.Linear with QuantLinear
+ │ └── Quantizer allocates scale Parameters using init_value
+ │ (no dependency on weight data)
+ └── stream_load_scales(from_ptq_ckpt) # fill weight/act/kv scales
+ └── End2EndTrainer.prepare_trainer()
+ ├── _init_optimizer() with id-deduped scale/LWC params
+ ├── patch_deepspeed_duplicate_check() # scales may be tied
+ └── HF Trainer.train()
+ # student + teacher forward → lm_loss + kd_loss composition
+ └── QAT.convert() + QAT.save()
+ # rank0-only QDQModule + consolidated state_dict → single rank write
+```
+
+## Usage
+
+### Prerequisites
+
+1. Install the dependencies: `deepspeed`, `safetensors`, and optionally `compressed-tensors` (for reading the exported fp8 checkpoint).
+2. Hardware: a multi-GPU node with NCCL; the ZeRO-3 path must be launched with `torchrun --nproc_per_node=N`.
+
+### Full two-stage flow
+
+#### Stage 1: PTQ calibration to produce the initial scales
+
+ZeRO-3 QAT **no longer** runs forward calibration at startup, so a PTQ checkpoint carrying scales must be produced first. Use an existing PTQ config (a single GPU is enough):
+
+```bash
+python tools/run.py \
+ -c configs/qwen3/ptq/fp8_static/qwen3-a3b_fp8_static.yaml \
+ --model-path /path/to/Qwen3-30B-A3B \
+ --save-path ./output_ptq_30b
+```
+
+The resulting `./output_ptq_30b/qwen3-a3b_fp8_static/model-*.safetensors` shards contain the `.weight_scale` and `.input_scale` tensors.
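+
+As an optional sanity check (the path follows the PTQ example above), the exported scale keys can be listed with `safetensors`:
+
+```python
+import glob
+
+from safetensors import safe_open
+
+# Look at the first PTQ shard and count the exported scale tensors.
+shard = sorted(glob.glob("./output_ptq_30b/qwen3-a3b_fp8_static/*.safetensors"))[0]
+with safe_open(shard, framework="pt") as reader:
+    scale_keys = [
+        key for key in reader.keys()
+        if key.endswith((".weight_scale", ".input_scale"))
+    ]
+print(f"{len(scale_keys)} scale tensors, e.g. {scale_keys[:3]}")
+```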
+
+#### Stage 2: ZeRO-3 QAT
+
+```bash
+bash scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh
+```
+
+Or launch it directly:
+
+```bash
+torchrun --nproc_per_node=8 tools/run.py \
+ -c configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml
+```
+
+After training finishes, the compressed weights are saved to `./output/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3/final_quant_checkpoint/`.
+
+### Minimal 4B smoke test
+
+A quick end-to-end check (2 GPUs, 5 training steps):
+
+```bash
+# Step 1: PTQ
+python tools/run.py \
+ -c configs/qwen3/ptq/fp8_static/qwen3-4b_fp8_static.yaml \
+ --model-path Qwen/Qwen3-4B \
+ --save-path ./output_ptq
+
+# Step 2: QAT
+bash scripts/qat/run_qat_for_qwen_4b_zero3.sh
+```
+
+## Configuration
+
+### New / changed fields for ZeRO-3 QAT (`compression.QAT`)
+
+| Field | Type | Default | Description |
+|------|------|--------|------|
+| `from_ptq_ckpt` | str / null | `null` | Output directory of a PTQ run with `save_format="real"`. **Required** under ZeRO-3; otherwise `before_train` raises. The path may point at the top level (a nested `final_quant_checkpoint/` is detected automatically). |
+| `lm_loss_weight` | float | `1.0` | Weight of the HF CausalLM CE loss. If 0, the lm loss is neither computed nor logged. |
+| `kd_loss_weight` | float | `0.0` | Weight of the KD loss (the variant selected by `loss_type`). If 0, KD is not computed and no teacher forward is run. |
+
+Note: training proceeds as long as at least one of `lm_loss_weight` / `kd_loss_weight` is > 0; if both are 0, `compute_loss` raises an error.
+
+### KD variants (`loss_type`)
+
+| `loss_type` | Description |
+|-------------|------|
+| `kl` | `KL(teacher \|\| student)` (forward KL), averaged over valid tokens |
+| `rkl` | `KL(student \|\| teacher)` (reverse / backward KL), averaged over valid tokens |
+| `mse` | MSE between student and teacher logits (averaged over valid tokens) |
+| `cakld` | Confidence-Aware KL Distillation: token-wise mix of `fkl` / `rkl` weighted by the teacher's probability on the label, `conf * rkl + (1 - conf) * fkl` |
+| `kd` | Classic temperature KD, `T² * KL(soft_student \|\| soft_teacher)`, kept for compatibility |
+| `kl_top_K` / `r_kl_top_K` | Forward / reverse KL over the top-K tokens. `K` can be embedded in the `loss_type` string (e.g. `kl_top_1000`) or set via `loss_topk` |
+| `origin` | Plain HF CE loss (equivalent to `kd_loss_weight = 0`) |
+
+The core formula when `loss_type = cakld` (see `_compute_kd_components`):
+
+```python
+# computed only on tokens where labels != -100
+forward_kl = KL(log_softmax(student), softmax(teacher))
+backward_kl = KL(log_softmax(teacher), softmax(student))
+conf = softmax(teacher).gather(-1, label)  # teacher's confidence on the target token
+cakld = (conf * backward_kl + (1 - conf) * forward_kl).mean()
+```
+
+### Training logs
+
+`QATSeq2SeqTrainer.log()` injects the metrics below into HF Trainer's standard log dict (so they show up in wandb / console / tqdm); each one appears only when its weight is > 0:
+
+| Metric | Meaning |
+|------|------|
+| `lm_loss` | HF CausalLM CE loss (computed only on assistant-reply positions, see below) |
+| `kd/` | The currently selected primary KD loss (`cakld` / `kl` / ...) |
+| `kd/forward_kl` | Diagnostic: `KL(teacher \|\| student)`, always printed when `kd_loss_weight > 0` |
+| `kd/backward_kl` | Diagnostic: `KL(student \|\| teacher)` |
+| `total_loss` | `lm_loss_weight * lm_loss + kd_loss_weight * kd/` |
+
+Sample output (30B-A3B, 3 training steps):
+
+```
+{'loss': 2.08, 'grad_norm': 43.1, 'learning_rate': 1e-6,
+ 'lm_loss': 1.23, 'kd/cakld': 0.0075, 'kd/forward_kl': 0.0075, 'kd/backward_kl': 0.0075,
+ 'total_loss': 1.24, 'epoch': 0.03}
+```
+
+### Dataset: supervise only the last assistant reply
+
+For JSONL SFT data (the `messages` / `conversations` / `input+output` schemas), `TextDataset._load_jsonl_data` now computes the loss only on the positions of the **last assistant reply**:
+
+- Concatenate the prompt (every turn before the last assistant message) and tokenize it via `apply_chat_template(..., add_generation_prompt=True)` to obtain `prompt_len`.
+- Tokenize the full conversation (including the last assistant reply) to obtain `input_ids`, then set `labels = input_ids.clone()`.
+- Set `labels[:, :prompt_len]` and all padding positions to `-100`; the HF CausalLM loss ignores them automatically.
+
+This masking (sketched below) matches the shift HF CausalLM models apply internally (`shift_logits[..., :-1]` aligned with `shift_labels[..., 1:]`), so no manual `roll` is needed.
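+
+A minimal sketch of this masking, assuming a single-turn chat in the `messages` schema and a Qwen tokenizer (illustrative only; the real logic lives in `TextDataset._load_jsonl_data`):
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
+messages = [
+    {"role": "user", "content": "What is 2 + 2?"},
+    {"role": "assistant", "content": "4."},
+]
+
+# Tokenize the prompt (everything before the last assistant turn) and the
+# full conversation, then blank out every label outside the reply.
+prompt_ids = tokenizer.apply_chat_template(
+    messages[:-1], add_generation_prompt=True, return_tensors="pt"
+)
+input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
+
+labels = input_ids.clone()
+labels[:, : prompt_ids.shape[1]] = -100  # prompt positions carry no loss
+if tokenizer.pad_token_id is not None:   # padding (if any) carries no loss
+    labels[input_ids == tokenizer.pad_token_id] = -100
+```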
+
+### DeepSpeed configuration
+
+See `configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json`. Key entries:
+
+```json
+{
+ "bf16": {"enabled": "auto"},
+ "zero_optimization": {
+ "stage": 3,
+ "stage3_gather_16bit_weights_on_model_save": true,
+ "overlap_comm": true,
+ "contiguous_gradients": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto"
+}
+```
+
+`hf_args` must set `bf16: true` (or `fp16: true`) to enable mixed precision explicitly; otherwise DeepSpeed's `zero3_linear_wrap` falls back to fp16 autocast while the model weights are bf16, causing a dtype mismatch.
+
+### Quantizer initial values
+
+When ZeRO-3 is enabled (or the weight is ZeRO-3 sharded / on the meta device / has 0 numel), `Quantizer._init_quant_params` no longer depends on the weight values and creates the `nn.Parameter`s directly from the shape:
+
+| Option | Default | Description |
+|------|--------|------|
+| `weight_scale_init_value` | `1.0` | Initial value of the weight-quantizer scale; kept whenever the key is not found in `from_ptq_ckpt` |
+| `activation_scale_init_value` | `1.0` | Initial value of the activation-quantizer scale |
+
+In typical cases the weight scales produced by PTQ are around `max(|W|) / 448 ≈ 1e-3`, so setting the init value to roughly `0.1` in the yaml is a sensible fallback (these values are overwritten whenever `from_ptq_ckpt` provides a match).
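+
+For reference, the rule of thumb behind that estimate, as a standalone sketch (not the PTQ implementation itself):
+
+```python
+import torch
+
+FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3
+
+def per_tensor_fp8_weight_scale(weight: torch.Tensor) -> torch.Tensor:
+    # max(|W|) / 448: for typical bf16 Linear weights this lands around 1e-3,
+    # far below the 1.0 default, which is why 0.1 is a safer fallback init.
+    return weight.abs().max().float() / FP8_E4M3_MAX
+
+print(per_tensor_fp8_weight_scale(torch.randn(4096, 4096) * 0.02))  # roughly 2e-4
+```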
+
+### Key yaml snippet
+
+```yaml
+compression:
+ name: QAT
+ quantization:
+ name: fp8_static
+ quant_method:
+ weight: per-tensor
+ activation: per-tensor
+ ignore_layers: ["lm_head", "embed_tokens", "gate.weight"]
+ QAT:
+ hf_dataset: null
+ from_ptq_ckpt: ./output_ptq_30b/qwen3-a3b_fp8_static
+ training_mode: end2end
+ dist_mode: hf
+ save_format: real
+ loss_type: cakld
+ lm_loss_weight: 1.0
+ kd_loss_weight: 1.0
+ plugin_config:
+ enable_scale: true
+ quant_config:
+ use_weight_quant: true
+ use_activation_quant: true
+ weight_scale_init_value: 0.1
+ activation_scale_init_value: 0.1
+ learnable:
+ act_scale: false
+ weight_scale: true
+ kv_scale: false
+ norm: false
+ lwc: false
+ hf_args:
+ bf16: true
+ per_device_train_batch_size: 1
+ learning_rate: 1.0e-6
+ gradient_checkpointing: true
+ deepspeed: configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json
+```
+
+Full examples:
+- `configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml`
+- `configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml`
+
+## Key implementation notes
+
+### Main APIs in `zero3_io.py`
+
+| Function / class | Purpose |
+|-----------|------|
+| `is_deepspeed_zero3_enabled()` | Checks via HF's `HfTrainerDeepSpeedConfig` weak-ref whether a ZeRO-3 config has been registered |
+| `is_zero3_param(p)` | Checks whether a Parameter carries the `ds_id / ds_status / ds_numel / ds_tensor` metadata |
+| `gathered_param_if_zero3(p, modifier_rank=None)` | Context manager; a no-op for non-ZeRO-3 parameters |
+| `LinearizedMoeExperts` | Generic empty per-expert `nn.Linear` container whose `forward` matches the HF fused experts |
+| `linearize_moe_experts_empty(model)` | Duck-typing scan and in-place replacement; construction happens inside `deepspeed.zero.Init` so the new parameters are partitioned immediately |
+| `zero3_empty_model_from_pretrained(model_path, ...)` | Builds an empty model via `no_init_weights` + `from_config` and linearizes fused MoE automatically |
+| `stream_load_weights(model, model_path)` | Streams the weights in, slicing fused MoE keys into the per-expert targets |
+| `stream_load_scales(model, ckpt_dir)` | Streams the scales in; supports `.weight_scale / .input_scale / .k_cache.scale / .v_cache.scale` |
+| `save_via_model_save_func(quant_model, save_func, path, prebuilt_state_dict)` | Calls the original save_func on rank 0 only, with `state_dict()` patched to return the consolidated dict |
+| `patch_deepspeed_duplicate_check()` | No-ops `DeepSpeedEngine._check_for_duplicates` so tied scale parameters are allowed |
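+
+A condensed bring-up sketch using these helpers (paths are placeholders, `angelslim.utils.zero3_io` assumes the re-export shown in the module layout, and in the real pipeline `Engine` / `BaseLLMModel.from_pretrained` drives these calls for you):
+
+```python
+from angelslim.utils.zero3_io import (
+    stream_load_scales,
+    stream_load_weights,
+    zero3_empty_model_from_pretrained,
+)
+
+# 1. Empty, already-partitioned model with fused MoE experts linearized.
+model = zero3_empty_model_from_pretrained("/path/to/Qwen3-30B-A3B", torch_dtype="auto")
+# 2. Stream the base weights from the HF checkpoint (rank 0 reads, every rank is filled).
+stream_load_weights(model, "/path/to/Qwen3-30B-A3B")
+# 3. After QuantLinear insertion, warm-start the quantizer scales from the PTQ checkpoint.
+stream_load_scales(model, "./output_ptq_30b/qwen3-a3b_fp8_static")
+```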
+
+### Changes to `Quantizer`
+
+- New `weight_shape` constructor argument: `QuantLinear.__init__` passes `(out_features, in_features)`, so the Quantizer can derive the scale shape without touching the weight data.
+- New `weight_scale_init_value` / `activation_scale_init_value` config options, used by `_init_quant_params` / `_init_lwc_params` when `_needs_external_weight_init(x)` is true.
+- The final `F.linear` in `QuantLinear.forward` is now wrapped in `torch.amp.autocast(device_type="cuda", enabled=False)`, with `input.dtype` aligned to `weight.dtype` beforehand, so DeepSpeed's `zero3_linear_wrap` autocast cannot turn the bf16 input back into fp16 (see the sketch below).
+- `fake_quant` casts `out` back to `x.dtype` at the end, preventing the `bf16 * fp32 = fp32` dtype leak.
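+
+A simplified sketch of the dtype-safe call described in the third bullet (names are illustrative; the real code sits at the end of `QuantLinear.forward`):
+
+```python
+import torch
+import torch.nn.functional as F
+
+def linear_without_autocast(x: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
+    # Disable autocast so zero3_linear_wrap cannot re-cast the bf16 input to
+    # fp16, and align the input dtype to the weight dtype before the matmul.
+    with torch.amp.autocast(device_type="cuda", enabled=False):
+        return F.linear(x.to(weight.dtype), weight, bias)
+```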
+
+### Optimizer de-duplication
+
+`End2EndTrainer._init_optimizer` collects the trainable scale / LWC parameters through `_unique_named_params(...)`, de-duplicating by `id()` so the same Parameter is never added to a param group twice (which otherwise happens when MoE experts share tensors). Combined with `patch_deepspeed_duplicate_check()`, this passes DeepSpeed's safety checks.
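+
+A minimal sketch of the `id()`-based de-duplication (illustrative; the actual helper is `_unique_named_params` inside `End2EndTrainer`):
+
+```python
+def unique_trainable_params(named_params):
+    """Keep each trainable Parameter exactly once, even when several module
+    paths reference the same underlying tensor (e.g. tied expert scales)."""
+    seen, unique = set(), []
+    for name, param in named_params:
+        if not param.requires_grad or id(param) in seen:
+            continue
+        seen.add(id(param))
+        unique.append((name, param))
+    return unique
+```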
+
+### Memory control in `convert` + `save`
+
+Under the ZeRO-3 path, `QAT.convert`:
+
+1. For every `QuantLinear`: **all ranks** enter `gathered_param_if_zero3` to obtain the full weight (keeping the NCCL collectives symmetric), but **only rank 0** keeps a CPU clone (see the sketch below).
+2. Rank 0 runs the fp8/int quantization once inside a temporary `QDQModule`, extracts `weight / weight_scale / input_scale / bias` into `self._rank0_state_dict`, then discards the temporary module.
+3. **The model structure is not modified**: keeping the `QuantLinear` modules means the second pass (collecting non-QuantLinear parameters such as embed, lm_head, layernorm, the MoE router gate) sees the same `named_parameters` order on every rank, so the collective gathers cannot deadlock.
+4. `QAT.save` passes `_rank0_state_dict` through to `save_via_model_save_func`, which patches `hf_model.state_dict` and calls the original `save_func.save(...)` on **rank 0 only**.
+
+During convert, the CPU peak on rank > 0 is roughly one layer's full weight (tens of MB to a few GB); the peak on rank 0 is roughly the accumulated consolidated state_dict (the full model size).
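+
+The per-layer pattern behind step 1, as a simplified sketch (`quant_linear_modules` stands in for the identically-ordered `(name, QuantLinear)` pairs collected on every rank):
+
+```python
+import torch.distributed as dist
+
+from angelslim.utils.zero3_io import gathered_param_if_zero3
+
+rank0_state_dict = {}
+for name, module in quant_linear_modules:  # same iteration order on every rank
+    with gathered_param_if_zero3(module.weight):  # all ranks join the all-gather
+        if not dist.is_initialized() or dist.get_rank() == 0:
+            # Only rank 0 keeps a CPU copy; the gathered shard is released on exit.
+            rank0_state_dict[f"{name}.weight"] = module.weight.detach().cpu().clone()
+```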
+
+## Verified scenarios
+
+| Scenario | Model | Hardware | Result |
+|------|------|------|------|
+| Dense ZeRO-3 QAT | Qwen3-4B | 2×H20 | ✓ PTQ→QAT→save runs end to end; the output loads in transformers |
+| MoE ZeRO-3 QAT | Qwen3-30B-A3B (48 layers × 128 experts) | 8×H20 | ✓ `stream_load_scales` hits 37248 scales, training loss is stable, 31 GB fp8 checkpoint produced |
+| KD + LM loss combination | same as above | same as above | ✓ `lm_loss / kd/cakld / kd/forward_kl / kd/backward_kl / total_loss` logged according to their weights |
+| Last-assistant-only supervision | TextDataset (jsonl) | - | ✓ the first valid label index lands right after `<\|im_start\|>assistant\n` |
+| Non-ZeRO-3 regression | Qwen3-4B single-GPU PTQ | 1×H20 | ✓ identical behaviour to main, no regression |
+
diff --git a/scripts/qat/run_qat_for_hunyuanv3_a20b_zero3.sh b/scripts/qat/run_qat_for_hunyuanv3_a20b_zero3.sh
new file mode 100755
index 00000000..e5b2d1c3
--- /dev/null
+++ b/scripts/qat/run_qat_for_hunyuanv3_a20b_zero3.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+export PYTORCH_ALLOC_CONF="expandable_segments:True"
+
+NPROC=${NPROC:-8}
+CONFIG=${CONFIG:-configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml}
+
+torchrun --nproc_per_node=${NPROC} \
+ tools/run.py \
+ -c "${CONFIG}"
diff --git a/scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh b/scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh
new file mode 100755
index 00000000..1b244127
--- /dev/null
+++ b/scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+export PYTORCH_ALLOC_CONF="expandable_segments:True"
+
+NPROC=${NPROC:-8}
+CONFIG=${CONFIG:-configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml}
+
+torchrun --nproc_per_node=${NPROC} \
+ tools/run.py \
+ -c "${CONFIG}"
diff --git a/scripts/qat/run_qat_for_qwen_4b_zero3.sh b/scripts/qat/run_qat_for_qwen_4b_zero3.sh
new file mode 100755
index 00000000..45181701
--- /dev/null
+++ b/scripts/qat/run_qat_for_qwen_4b_zero3.sh
@@ -0,0 +1,14 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Smoke-test launcher: Qwen3-4B with DeepSpeed ZeRO-3. Matches the full
+# 30B-A3B flow but on a model small enough to iterate quickly.
+
+export PYTORCH_ALLOC_CONF="expandable_segments:True"
+
+NPROC=${NPROC:-2}
+CONFIG=${CONFIG:-configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml}
+
+torchrun --nproc_per_node=${NPROC} \
+ tools/run.py \
+ -c "${CONFIG}"
diff --git a/tools/run.py b/tools/run.py
index c4135199..88d84451 100644
--- a/tools/run.py
+++ b/tools/run.py
@@ -112,6 +112,7 @@ def multi_nodes_run(config):
shuffle=dataset_config.shuffle,
inference_settings=dataset_config.inference_settings,
use_audio_in_video=model_config.use_audio_in_video,
+ is_sft_data=dataset_config.is_sft_data,
)
# Step 6: Initialize compressor
@@ -267,6 +268,32 @@ def weight_only_run(config):
)
+def _prewarm_hf_deepspeed_config(config):
+ """Pre-construct ``Seq2SeqTrainingArguments`` so HF's
+ ``HfTrainerDeepSpeedConfig`` weak-ref is registered BEFORE
+ ``from_pretrained`` runs. That is what flips
+ ``is_deepspeed_zero3_enabled()`` to True and makes our
+ ``BaseLLMModel.from_pretrained`` take the ZeRO-3 path.
+
+ Returns the constructed TrainingArguments (kept alive via the caller's
+ local variable) or None if not applicable.
+ """
+ compress_cfg = getattr(config, "compression_config", None)
+ qat_cfg = getattr(compress_cfg, "QAT", None) if compress_cfg is not None else None
+ hf_args = getattr(qat_cfg, "hf_args", None) if qat_cfg is not None else None
+ if not hf_args or not hf_args.get("deepspeed"):
+ return None
+
+ from transformers import Seq2SeqTrainingArguments
+
+ trainer_args = Seq2SeqTrainingArguments(
+ output_dir=config.global_config.save_path,
+ **hf_args,
+ )
+ print_info("[DeepSpeed pre-warm] HfTrainerDeepSpeedConfig registered before model load.")
+ return trainer_args
+
+
def run(config):
"""
Run the LLM compression process based on the provided configuration.
@@ -295,6 +322,12 @@ def run(config):
weight_only_run(config)
return
+ # QAT + DeepSpeed: register HfTrainerDeepSpeedConfig BEFORE loading the
+ # model so ``from_pretrained`` takes the ZeRO-3 path. No-op otherwise.
+ # The returned object must stay alive until after the model is built
+ # because HF's weak-ref mechanism drops the config otherwise.
+ _hf_ds_args = _prewarm_hf_deepspeed_config(config)
+
# Step 2: Execute complete pipeline
slim_engine = Engine()
@@ -312,6 +345,9 @@ def run(config):
attn_implementation=model_config.attn_implementation,
deploy_backend=global_config.deploy_backend,
)
+ # Safe to release now: the model is built and any deepspeed.zero.Init
+ # effects have already happened on all parameters.
+ del _hf_ds_args
# Step 4: Prepare data (optional custom dataloader)
if compress_config.need_dataset:
@@ -327,6 +363,7 @@ def run(config):
use_audio_in_video=model_config.use_audio_in_video,
model_name=model_config.name,
quantization_config=compress_config.quantization,
+ is_sft_data=dataset_config.is_sft_data,
)
# Step 5: Initialize compressor