From d6a5848cd0877904e1e7dbfcd6416fdfe69dbc1c Mon Sep 17 00:00:00 2001 From: root Date: Sun, 26 Apr 2026 15:06:05 +0800 Subject: [PATCH 01/13] add deepspeed QAT --- angelslim/compressor/qat/modules/quantizer.py | 113 ++- .../compressor/qat/plugins/learnable_scale.py | 57 +- angelslim/compressor/qat/qat.py | 56 +- .../qat/trainers/end2end_trainer.py | 70 +- angelslim/engine.py | 7 + angelslim/models/base_model.py | 29 +- angelslim/utils/__init__.py | 13 + angelslim/utils/config_parser.py | 6 + angelslim/utils/utils.py | 14 +- angelslim/utils/zero3_io.py | 797 ++++++++++++++++++ .../learn_scale/ds_config_zero3.json | 45 + ..._fp8_static_end2end_learn_scale_zero3.yaml | 81 ++ ..._fp8_static_end2end_learn_scale_zero3.yaml | 69 ++ scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh | 11 + scripts/qat/run_qat_for_qwen_4b_zero3.sh | 14 + scripts/qat/test_moe_zero3_build.py | 98 +++ tools/run.py | 37 + 17 files changed, 1471 insertions(+), 46 deletions(-) create mode 100644 angelslim/utils/zero3_io.py create mode 100644 configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json create mode 100644 configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml create mode 100644 configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml create mode 100755 scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh create mode 100755 scripts/qat/run_qat_for_qwen_4b_zero3.sh create mode 100644 scripts/qat/test_moe_zero3_build.py diff --git a/angelslim/compressor/qat/modules/quantizer.py b/angelslim/compressor/qat/modules/quantizer.py index bc81e998..e0493709 100644 --- a/angelslim/compressor/qat/modules/quantizer.py +++ b/angelslim/compressor/qat/modules/quantizer.py @@ -18,6 +18,8 @@ import torch.nn as nn import torch.nn.functional as F +from ....utils import is_deepspeed_zero3_enabled, is_zero3_param + FP8_E4M3_QMIN = -448 FP8_E4M3_QMAX = 448 @@ -49,10 +51,30 @@ def _parse_bits_and_dtype(qtype_str): class 
Quantizer(nn.Module): - def __init__(self, config, quant_info, x=None, is_act=False, resume=False, num_heads=-1): + def __init__( + self, + config, + quant_info, + x=None, + is_act=False, + resume=False, + num_heads=-1, + weight_shape=None, + ): super().__init__() self.is_act = is_act self.num_heads = num_heads + # ``weight_shape`` lets the caller pre-declare the (out_features, + # in_features) of the parent Linear so we can size weight-side + # quantizer Parameters without ever touching the (possibly ZeRO-3 + # sharded) weight tensor. + self.weight_shape = ( + (int(weight_shape[0]), int(weight_shape[1])) if weight_shape is not None else None + ) + # Configurable initial values used when ZeRO-3 is active and we + # cannot depend on the weight data. + self.weight_scale_init_value = float(config.get("weight_scale_init_value", 1.0)) + self.activation_scale_init_value = float(config.get("activation_scale_init_value", 1.0)) info = quant_info.quant_algo_info["w"] self.group_size = quant_info.quant_algo_info.get("w_group_size", -1) rewrite_conf = config.get("weight", {}) @@ -117,8 +139,21 @@ def _init_quant_params(self, x): self.scale = self.zero_point = None if self.resume: self.init = True - zp = torch.empty(1) if not self.is_sym else None - self._set_quant_parameters(torch.empty(1), zp) + init_val = self.activation_scale_init_value + scale = torch.full((1,), init_val, dtype=torch.float32) + zp = torch.zeros(1, dtype=torch.float32) if not self.is_sym else None + self._set_quant_parameters(scale, zp) + return + + # Weight-side path. If we cannot use ``x`` (ZeRO-3 sharded, + # meta, or simply not provided), allocate Parameters by shape + # and ``weight_scale_init_value``. 
+ if self._needs_external_weight_init(x): + shape = self._weight_scale_shape_from_meta() + init_val = self.weight_scale_init_value + scale = torch.full(shape, init_val, dtype=torch.float32) + zp = torch.zeros(shape, dtype=torch.float32) if not self.is_sym else None + self._set_quant_parameters(scale, zp) return if self.is_sym: @@ -131,6 +166,52 @@ def _init_quant_params(self, x): ) self._set_quant_parameters(scale, zp.round()) + def _needs_external_weight_init(self, x): + """True when weight-side init must skip data-dependent computation + and instead allocate Parameters from shape + init_value. + + Triggered by: + * DeepSpeed ZeRO-3 active (HF integration registered) + * ``x`` is a ZeRO-3 sharded Parameter + * ``x`` is None / on meta device / empty + """ + if is_deepspeed_zero3_enabled(): + return True + if x is None: + return True + if is_zero3_param(x): + return True + if hasattr(x, "device") and x.device.type == "meta": + return True + if hasattr(x, "numel") and x.numel() == 0: + return True + return False + + def _weight_2d_shape(self): + """Resolve (out_features, in_features) for the underlying Linear. + Callers must have passed ``weight_shape`` via ``QuantLinear``.""" + if self.weight_shape is not None: + return self.weight_shape + raise RuntimeError( + "Quantizer needs ``weight_shape`` to size weight scale without a " + "concrete tensor (set in QuantLinear.__init__)." 
+ ) + + def _weight_scale_shape_from_meta(self): + out_dim, in_dim = self._weight_2d_shape() + if self.granularity == "per-channel": + return (out_dim, 1) + if self.granularity == "per-group": + if not self.group_size or self.group_size <= 0: + raise ValueError("per-group quantization requires positive group_size.") + if in_dim % self.group_size != 0: + raise ValueError( + f"dim 1 ({in_dim}) not divisible by group_size ({self.group_size})" + ) + return (out_dim, in_dim // self.group_size) + # per-tensor and any reduce-to-scalar variant + return (1,) + def _init_lwc_params(self, x, config): lwc_cfg = config.get("lwc", {}) if isinstance(lwc_cfg, dict): @@ -141,11 +222,18 @@ def _init_lwc_params(self, x, config): self.lwc_init_value = 4.0 if self.lwc: - if x.dim() != 2: - x_for_shape = x.flatten(1) + # Resolve (out_dim, in_dim) without depending on ``x`` data. + if self._needs_external_weight_init(x): + out_dim, in_dim = self._weight_2d_shape() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: - x_for_shape = x - out_dim, in_dim = x_for_shape.shape + if x.dim() != 2: + x_for_shape = x.flatten(1) + else: + x_for_shape = x + out_dim, in_dim = x_for_shape.shape + device = x.device + if self.granularity == "per-group": if not self.group_size or self.group_size <= 0: raise ValueError("per-group quantization requires positive group_size.") @@ -158,7 +246,7 @@ def _init_lwc_params(self, x, config): dim1 = 1 init = ( - torch.ones((dim1, 1), device=x.device, dtype=torch.float32) * self.lwc_init_value + torch.ones((dim1, 1), device=device, dtype=torch.float32) * self.lwc_init_value ) self.clip_factor_w_max = nn.Parameter(init.clone(), requires_grad=True) self.clip_factor_w_min = nn.Parameter(init.clone(), requires_grad=True) @@ -516,8 +604,15 @@ def __init__( self.register_parameter("bias", org_module.bias) self.use_weight_quant = use_weight_quant self.use_act_quant = use_act_quant + # Under ZeRO-3 the weight Parameter ``org_module.weight`` may be a 
+ # zero-numel shard. Pass an explicit (out, in) shape so the weight + # quantizer can size its scale Parameter from the Linear shape + # rather than inspecting the (possibly sharded) tensor. + weight_shape = (org_module.out_features, org_module.in_features) if self.use_weight_quant: - self.weight_quantizer = Quantizer(config, quant_info, x=org_module.weight) + self.weight_quantizer = Quantizer( + config, quant_info, x=org_module.weight, weight_shape=weight_shape, + ) if self.use_act_quant: self.act_quantizer = Quantizer(config, quant_info, is_act=True, resume=resume) diff --git a/angelslim/compressor/qat/plugins/learnable_scale.py b/angelslim/compressor/qat/plugins/learnable_scale.py index 8f01514f..99521691 100644 --- a/angelslim/compressor/qat/plugins/learnable_scale.py +++ b/angelslim/compressor/qat/plugins/learnable_scale.py @@ -15,7 +15,13 @@ import torch from tqdm import tqdm -from ....utils import print_info, set_op_by_name +from ....utils import ( + gathered_params_if_zero3, + is_deepspeed_zero3_enabled, + print_info, + set_op_by_name, + stream_load_scales, +) from ..modules.quantizer import QuantLinear from .base_plugin import BasePlugin from .plugin_manager import PluginManager @@ -29,11 +35,22 @@ @PluginManager.plugin("learnable_scale") class LearnableScalePlugin(BasePlugin): - def __init__(self, quant_info=None, ignore_layers=None, resume_ckpt_dir=None, **kwargs): + def __init__( + self, + quant_info=None, + ignore_layers=None, + resume_ckpt_dir=None, + from_ptq_ckpt_dir=None, + **kwargs, + ): super().__init__(**kwargs) self.quant_info = quant_info self.ignore_layers = ignore_layers self.resume_ckpt_dir = resume_ckpt_dir + # Optional warm-start from a PTQ "real" checkpoint (only scales are + # read; base weights stay as loaded by from_pretrained). Required + # under DeepSpeed ZeRO-3. 
+ self.from_ptq_ckpt_dir = from_ptq_ckpt_dir self.use_weight_quant = self.config.get("use_weight_quant", False) self.use_activation_quant = self.config.get("use_activation_quant", False) self.fp8_attn = self.config.get("fp8_attn", False) @@ -47,9 +64,25 @@ def __init__(self, quant_info=None, ignore_layers=None, resume_ckpt_dir=None, ** self.learn_norm = learnable_cfg.get("norm", False) def before_train(self, **kwargs): + zero3 = is_deepspeed_zero3_enabled() + if zero3 and not self.from_ptq_ckpt_dir: + raise ValueError( + "DeepSpeed ZeRO-3 QAT requires `compression.QAT.from_ptq_ckpt` " + "to warm-start scales (lazy_init via forward is impossible " + "on sharded weights)." + ) + # Retrieve KV head count from model config for per-head quantization model_config = getattr(self.quant_model.model, "config", None) num_kv_heads = getattr(model_config, "num_key_value_heads", -1) + # Pre-allocate ``act_quantizer.scale`` as a Parameter whenever we + # plan to fill it from a checkpoint (full resume OR PTQ warm-start + # OR ZeRO-3 — where lazy_init is impossible). + act_preallocate = ( + self.resume_ckpt_dir is not None + or self.from_ptq_ckpt_dir is not None + or zero3 + ) for name, module in self.quant_model.model.named_modules(): if isinstance(module, torch.nn.Linear): if any(ig in name for ig in self.ignore_layers): @@ -67,7 +100,7 @@ def before_train(self, **kwargs): self.quant_info, self.use_weight_quant, self.use_activation_quant, - resume=self.resume_ckpt_dir is not None, + resume=act_preallocate, qkv_config=qkv_cfg, ) set_op_by_name(self.quant_model.model, name, q_linear) @@ -78,10 +111,18 @@ def before_train(self, **kwargs): print_info(self.quant_model.model) + # Warm-start scales from a previous PTQ "real" checkpoint. Only + # quantizer Parameters are touched; base Linear weights are NOT + # overwritten. 
+ if self.from_ptq_ckpt_dir is not None: + stream_load_scales(self.quant_model.model, self.from_ptq_ckpt_dir) + if ( self.use_activation_quant and not q_linear.act_quantizer.dynamic and self.resume_ckpt_dir is None + and not zero3 + and self.from_ptq_ckpt_dir is None ): self._lazy_init(**kwargs) @@ -284,5 +325,13 @@ def _get_qkv_config_for_layer(name, quant_config): @torch.no_grad() def quant_inplace(model): for _, module in model.named_modules(): - if isinstance(module, QuantLinear): + if not isinstance(module, QuantLinear): + continue + # Gather the weight together with all weight_quantizer Parameters + # (scale / zero_point / optional LWC clip factors) so the + # fake-quant runs on the full materialised tensor under ZeRO-3. + params = [module.weight] + if hasattr(module, "weight_quantizer"): + params.extend(module.weight_quantizer.parameters(recurse=True)) + with gathered_params_if_zero3(params, modifier_rank=None): module.weight.data = module.weight_quantizer(module.weight.data) diff --git a/angelslim/compressor/qat/qat.py b/angelslim/compressor/qat/qat.py index a4274595..10e7fc4b 100644 --- a/angelslim/compressor/qat/qat.py +++ b/angelslim/compressor/qat/qat.py @@ -17,7 +17,13 @@ import torch from safetensors.torch import save_file -from ...utils import print_info, set_op_by_name +from ...utils import ( + gathered_param_if_zero3, + model_has_zero3_params, + print_info, + save_via_model_save_func, + set_op_by_name, +) from ..compressor_factory import CompressorFactory from ..quant.modules.helper_layer import QDQModule from .modules.quantizer import QuantLinear @@ -57,6 +63,7 @@ def _init_plugins(self): quant_info=self.quant_info, ignore_layers=self.config["compress_config"].quantization.ignore_layers, resume_ckpt_dir=self.config["compress_config"].QAT.resume_ckpt_dir, + from_ptq_ckpt_dir=self.config["compress_config"].QAT.from_ptq_ckpt, config=self.plugin_config.get("quant_config", {}), quant_model=self.quant_model, ) @@ -72,6 +79,14 @@ def _init_trainer(self): 
def run(self, dataloader): self.trainer.run(dataloader) + @staticmethod + def _gather_clone(tensor): + """Detach + CPU-clone a tensor, gathering if it is a ZeRO-3 shard.""" + if tensor is None: + return None + with gathered_param_if_zero3(tensor): + return tensor.detach().cpu().clone() + def convert(self): if self.save_fmt not in ("real", "real_and_kvcache"): return @@ -86,24 +101,28 @@ def convert(self): ] for name, module in quant_linear_modules: + weight = self._gather_clone(module.weight) + bias = self._gather_clone(getattr(module, "bias", None)) + weight_scale = None if hasattr(module, "weight_quantizer"): - weight_scale = module.weight_quantizer.scale.data.clone() + weight_scale = self._gather_clone(module.weight_quantizer.scale) input_scale = None if module.use_act_quant and hasattr(module, "act_quantizer"): act_quantizer = module.act_quantizer if hasattr(act_quantizer, "scale") and act_quantizer.scale is not None: - input_scale = act_quantizer.scale.data.clone() + input_scale = self._gather_clone(act_quantizer.scale) qdq_module = QDQModule( quant_algo=quant_algo, - weight=module.weight, + weight=weight, weight_scale=weight_scale, - bias=module.bias, + bias=bias, group_size=( module.weight_quantizer.group_size - if hasattr(module.weight_quantizer, "group_size") + if hasattr(module, "weight_quantizer") + and hasattr(module.weight_quantizer, "group_size") else 128 ), input_scale=input_scale, @@ -126,7 +145,18 @@ def _save_kv_cache_scales(self, save_path: str): else: continue scale_key = f"{cache_name}.scale" - kv_scales[scale_key] = module.qkv_quantizer.scale.data.clone().float().cpu() + scale_tensor = self._gather_clone(module.qkv_quantizer.scale) + if scale_tensor is not None: + kv_scales[scale_key] = scale_tensor.float() + + if model_has_zero3_params(self.quant_model.model): + # Only rank0 writes the file. 
+ rank = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() else 0 + ) + if rank != 0: + return os.makedirs(save_path, exist_ok=True) out_file = os.path.join(save_path, "kv_cache_scales.safetensors") @@ -146,7 +176,11 @@ def save(self, save_path: str): # "real": save real-quant model via model-specific save function elif self.save_fmt == "real": save_func = self.quant_model.get_save_func()(self.quant_model) - save_func.save(os.path.join(save_path, "final_quant_checkpoint")) + save_via_model_save_func( + self.quant_model, + save_func, + os.path.join(save_path, "final_quant_checkpoint"), + ) # "save_kvcache_only": only export KV cache scales (kv_cache_scales.safetensors) elif self.save_fmt == "save_kvcache_only": @@ -155,7 +189,11 @@ def save(self, save_path: str): # "real_and_kvcache": save real-quant model AND KV cache scales elif self.save_fmt == "real_and_kvcache": save_func = self.quant_model.get_save_func()(self.quant_model) - save_func.save(os.path.join(save_path, "final_quant_checkpoint")) + save_via_model_save_func( + self.quant_model, + save_func, + os.path.join(save_path, "final_quant_checkpoint"), + ) self._save_kv_cache_scales(os.path.join(save_path, "final_quant_checkpoint")) else: diff --git a/angelslim/compressor/qat/trainers/end2end_trainer.py b/angelslim/compressor/qat/trainers/end2end_trainer.py index fb5daf6d..40387b87 100644 --- a/angelslim/compressor/qat/trainers/end2end_trainer.py +++ b/angelslim/compressor/qat/trainers/end2end_trainer.py @@ -18,11 +18,28 @@ from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments from ....data.qat_dataset import QATDataset -from ....utils import print_info +from ....utils import patch_deepspeed_duplicate_check, print_info from ..plugins.learnable_scale import set_quant_state from .trainer_factory import TrainerFactory +def _unique_named_params(model, predicate): + """Collect parameters matching ``predicate`` with id-based de-duplication. 
+ + Some QAT setups share a single scale Parameter across multiple + QuantLinear views (e.g. MoE experts built from a shared tensor). HF / + DeepSpeed optimizer init rejects duplicates, so we de-dup by ``id``. + """ + seen = set() + result = [] + for name, param in model.named_parameters(): + if id(param) in seen or not predicate(name, param): + continue + seen.add(id(param)) + result.append(param) + return result + + class QATSeq2SeqTrainer(Seq2SeqTrainer): def __init__(self, *args, loss_config=None, quant_config=None, **kwargs): super().__init__(*args, **kwargs) @@ -173,13 +190,13 @@ def __init__(self, quant_model, config, plugin_manager): def _init_optimizer(self): lr = float(self.config["compress_config"].QAT.hf_args.get("learning_rate", 1e-5)) wd = float(self.config["compress_config"].QAT.hf_args.get("weight_decay", 0)) + scale_params = _unique_named_params( + self.quant_model.model, + lambda n, p: p.requires_grad and ("scale" in n or "zero_point" in n), + ) params = [ { - "params": [ - p - for n, p in self.quant_model.model.named_parameters() - if "scale" in n or "zero_point" in n - ], + "params": scale_params, "weight_decay": wd, "lr": lr, } @@ -191,6 +208,7 @@ def _init_optimizer(self): .get("lwc", {}) .get("enable_lwc", False) ) + lwc_param_count = 0 if enable_lwc: lwc_lr = float( self.config["compress_config"] @@ -198,30 +216,42 @@ def _init_optimizer(self): .get("lwc", {}) .get("lwc_lr", 1e-1) ) - lwc_params = [ - { - "params": [ - p - for n, p in self.quant_model.model.named_parameters() - if "clip_factor_w_max" in n or "clip_factor_w_min" in n - ], - "weight_decay": wd, - "lr": lwc_lr, - } - ] - params.extend(lwc_params) + lwc_params = _unique_named_params( + self.quant_model.model, + lambda n, p: p.requires_grad + and ("clip_factor_w_max" in n or "clip_factor_w_min" in n), + ) + lwc_param_count = len(lwc_params) + params.append( + {"params": lwc_params, "weight_decay": wd, "lr": lwc_lr} + ) + + if not any(group["params"] for group in params): + raise 
ValueError("QAT optimizer has no trainable parameters.") self.optimizer = torch.optim.AdamW(params) if enable_lwc: - print_info(f"Init optimizer with learnable lr={lr} lwc_lr={lwc_lr} weight_decay={wd}") + print_info( + f"Init optimizer with {len(scale_params)} scale params, " + f"{lwc_param_count} lwc params, lr={lr} lwc_lr={lwc_lr} weight_decay={wd}" + ) else: - print_info(f"Init optimizer with learnable lr={lr} weight_decay={wd}") + print_info( + f"Init optimizer with {len(scale_params)} scale params, " + f"lr={lr} weight_decay={wd}" + ) def prepare_trainer(self): if self.training_mode == "blockwise": return if self.training_mode == "end2end" and self.dist_mode == "hf": self._init_optimizer() + # When DeepSpeed is used, neutralize its duplicate-parameter + # check: it rejects param-groups that share tensors, which our + # scale/zero_point setup can legally have (shared tensors + # across views). Idempotent and a no-op if deepspeed is absent. + if self.config["compress_config"].QAT.hf_args.get("deepspeed") is not None: + patch_deepspeed_duplicate_check() self.external_trainer = QATSeq2SeqTrainer( model=self.quant_model.model, processing_class=self.quant_model.tokenizer, diff --git a/angelslim/engine.py b/angelslim/engine.py index 757d8741..7feee90f 100644 --- a/angelslim/engine.py +++ b/angelslim/engine.py @@ -101,6 +101,13 @@ def prepare_model( assert model_name, "model_name must be specified." assert model_path, "model_path must be specified." + # Normalize device_map for DeepSpeed ZeRO / distributed training: YAML + # configs often write ``None`` / ``"None"`` / ``"distributed"`` to + # mean "no pre-placement, let DeepSpeed shard". HF only accepts + # Python ``None`` there. 
+ if isinstance(device_map, str) and device_map.lower() in ("none", "distributed"): + device_map = None + # Initialize slim model by ModelFactory self.slim_model = SlimModelFactory.create( model_name, model=model, deploy_backend=deploy_backend diff --git a/angelslim/models/base_model.py b/angelslim/models/base_model.py index 3f15befb..39ee6ece 100644 --- a/angelslim/models/base_model.py +++ b/angelslim/models/base_model.py @@ -24,7 +24,13 @@ from ..compressor.quant.core import QuantConfig from ..compressor.quant.modules import NVFP4QDQModule, QDQModule -from ..utils import common_prefix, print_info +from ..utils import ( + common_prefix, + is_deepspeed_zero3_enabled, + print_info, + stream_load_weights, + zero3_empty_model_from_pretrained, +) __all__ = ["BaseLLMModel"] @@ -65,6 +71,27 @@ def from_pretrained( using_multi_nodes=False, attn_implementation="default", ): + # DeepSpeed ZeRO-3 path: build an empty sharded model on every rank + # via deepspeed.zero.Init, linearize fused MoE experts, then stream + # the safetensors checkpoint into the (sharded) parameters. This + # avoids HF's path that materialises the full state_dict on every + # rank's CPU before sharding. 
+ if is_deepspeed_zero3_enabled(): + log_prefix = f"[{type(self).__name__}.from_pretrained]" + self.model = zero3_empty_model_from_pretrained( + model_path, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + use_cache=use_cache, + attn_implementation=attn_implementation, + log_prefix=log_prefix, + ) + stream_load_weights(self.model, model_path, log_prefix=log_prefix) + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=trust_remote_code + ) + return + kwargs = dict( torch_dtype=torch_dtype, device_map=device_map, diff --git a/angelslim/utils/__init__.py b/angelslim/utils/__init__.py index d80be244..46e1c0d9 100644 --- a/angelslim/utils/__init__.py +++ b/angelslim/utils/__init__.py @@ -30,3 +30,16 @@ from .utils import print_with_rank # noqa: F401 from .utils import rank0_print # noqa: F401 from .utils import set_op_by_name # noqa: F401 +from .zero3_io import consolidated_state_dict # noqa: F401 +from .zero3_io import gathered_param_if_zero3 # noqa: F401 +from .zero3_io import gathered_params_if_zero3 # noqa: F401 +from .zero3_io import is_deepspeed_zero3_enabled # noqa: F401 +from .zero3_io import is_zero3_param # noqa: F401 +from .zero3_io import linearize_moe_experts_empty # noqa: F401 +from .zero3_io import LinearizedMoeExperts # noqa: F401 +from .zero3_io import model_has_zero3_params # noqa: F401 +from .zero3_io import patch_deepspeed_duplicate_check # noqa: F401 +from .zero3_io import save_via_model_save_func # noqa: F401 +from .zero3_io import stream_load_scales # noqa: F401 +from .zero3_io import stream_load_weights # noqa: F401 +from .zero3_io import zero3_empty_model_from_pretrained # noqa: F401 diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py index 77481b9e..926284bc 100644 --- a/angelslim/utils/config_parser.py +++ b/angelslim/utils/config_parser.py @@ -272,6 +272,12 @@ class QATTrainingConfig: hf_dataset: Optional[str] = None do_train: bool = field(default=True) 
resume_ckpt_dir: Optional[str] = None + # Optional warm-start directory (a previous ``save_fmt="real"`` output). + # When set, scales / zero_points / kv-cache scales are loaded from this + # directory into the freshly-created ``QuantLinear`` quantizers; base + # ``Linear`` weights still come from ``model.model_path``. Required when + # DeepSpeed ZeRO-3 is enabled (no calibration is possible on shards). + from_ptq_ckpt: Optional[str] = None loss_type: str = field(default="origin") loss_topk: Optional[int] = None kd_temperature: float = field(default=1.0) diff --git a/angelslim/utils/utils.py b/angelslim/utils/utils.py index 36422758..e10f2e98 100644 --- a/angelslim/utils/utils.py +++ b/angelslim/utils/utils.py @@ -49,10 +49,18 @@ def set_op_by_name(layer, name, new_module): if len(levels) > 1: mod_ = layer for l_idx in range(len(levels) - 1): - if levels[l_idx].isdigit(): - mod_ = mod_[int(levels[l_idx])] + part = levels[l_idx] + if part.isdigit(): + # Prefer integer indexing for nn.ModuleList / nn.Sequential; + # fall back to getattr for custom containers (e.g. our + # LinearizedMoeExperts that registers experts via + # ``setattr(self, str(idx), ...)``). + try: + mod_ = mod_[int(part)] + except (TypeError, IndexError, KeyError): + mod_ = getattr(mod_, part) else: - mod_ = getattr(mod_, levels[l_idx]) + mod_ = getattr(mod_, part) setattr(mod_, levels[-1], new_module) else: setattr(layer, name, new_module) diff --git a/angelslim/utils/zero3_io.py b/angelslim/utils/zero3_io.py new file mode 100644 index 00000000..37e5b9b6 --- /dev/null +++ b/angelslim/utils/zero3_io.py @@ -0,0 +1,797 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""All-in-one DeepSpeed ZeRO-3 helpers for AngelSlim QAT. + +This module concentrates everything ZeRO-3-specific so that the rest of the +codebase touches only a few thin call-sites: + + * ZeRO-3 detection / parameter gathering helpers. + * Empty model construction under ``deepspeed.zero.Init`` plus a generic + "linearize fused MoE experts" pass that builds *empty* per-expert + ``nn.Linear`` modules (no copy from the old fused tensor required — + weights are filled later from the safetensors checkpoint). + * Streaming weight loader: walks the safetensors shards once, on rank 0 + only, and broadcasts each tensor into the (possibly sharded) target + parameter via ``GatheredParameters(modifier_rank=0)``. Handles fused + MoE keys (``...experts.gate_up_proj``, ``...experts.down_proj``) by + slicing per expert into the linearized targets. + * Streaming scale loader for QAT warm-start: reads ``*.weight_scale``, + ``*.input_scale``, ``*.k_cache.scale``, ``*.v_cache.scale`` keys from a + PTQ "real" checkpoint and writes them into the freshly-created + ``QuantLinear`` quantizer parameters. + * Saving: gather a sharded model into a rank-0 CPU state_dict and call + the model-specific save_func by patching ``state_dict``. + * Optimizer-side patches needed because QAT scale parameters are tied + across multiple modules in a layer. + +By design, **nothing in this file mutates the model when ZeRO-3 is not +enabled** (each helper is a no-op or behaves identically to the +non-distributed path). Importing this module is therefore safe in any +configuration. 
+""" + +from __future__ import annotations + +import gc +import glob +import json +import os +from contextlib import contextmanager, nullcontext +from typing import Optional + +import torch +import torch.nn as nn + +from .lazy_imports import deepspeed +from .utils import find_parent_layer_and_sub_name, print_info + +ZERO3_PARAM_ATTRS = ("ds_id", "ds_status", "ds_numel", "ds_tensor") + + +# --------------------------------------------------------------------------- +# Basic detection / context helpers +# --------------------------------------------------------------------------- + + +def is_deepspeed_zero3_enabled() -> bool: + """True iff HuggingFace's ``HfTrainerDeepSpeedConfig`` is registered with + ZeRO stage 3. Returns False if ``transformers``/``deepspeed`` is not + importable.""" + try: + from transformers.integrations.deepspeed import ( + is_deepspeed_zero3_enabled as _hf, + ) + + return bool(_hf()) + except Exception: # noqa: BLE001 + return False + + +def is_zero3_param(x) -> bool: + """True iff ``x`` is a ZeRO-3 sharded parameter (``deepspeed.zero.Init`` + has injected its bookkeeping attributes).""" + if not isinstance(x, torch.nn.Parameter): + return False + return any(hasattr(x, attr) for attr in ZERO3_PARAM_ATTRS) + + +@contextmanager +def gathered_param_if_zero3(x, modifier_rank: Optional[int] = None): + """All-gather a ZeRO-3 shard for the lifetime of the block. + + Pure no-op (yields ``x`` unchanged) when ``x`` is not a ZeRO-3 shard, so + callers can always wrap their critical sections without branching. 
+ """ + if is_zero3_param(x): + ctx = deepspeed.zero.GatheredParameters([x], modifier_rank=modifier_rank) + else: + ctx = nullcontext() + with ctx: + yield x + + +@contextmanager +def gathered_params_if_zero3(params, modifier_rank: Optional[int] = None): + """Batched variant of :func:`gathered_param_if_zero3`.""" + params = [p for p in params if p is not None] + z3 = [p for p in params if is_zero3_param(p)] + if z3: + ctx = deepspeed.zero.GatheredParameters(z3, modifier_rank=modifier_rank) + else: + ctx = nullcontext() + with ctx: + yield params + + +def model_has_zero3_params(model) -> bool: + return any(is_zero3_param(p) for p in model.parameters()) + + +def _rank() -> int: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_rank() + return 0 + + +def _cleanup() -> None: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Empty MoE linearization +# --------------------------------------------------------------------------- + + +class LinearizedMoeExperts(nn.Module): + """Empty per-expert ``nn.Linear`` container for fused MoE experts. + + This module mirrors the ``forward`` of HuggingFace's fused experts but + holds parameters as ``num_experts`` triplets of ``(gate_proj, up_proj, + down_proj)`` ``nn.Linear`` modules. Construction is **purely structural** + — no weight copy from the old fused tensor — so it is safe to instantiate + under ``deepspeed.zero.Init`` and let the streaming loader fill in the + weights afterwards. 
+ """ + + _angelslim_linearized_moe = True + + def __init__( + self, + num_experts: int, + hidden_dim: int, + intermediate_dim: int, + act_fn, + dtype=torch.bfloat16, + device=None, + config=None, + ): + super().__init__() + self.num_experts = int(num_experts) + self.hidden_dim = int(hidden_dim) + self.intermediate_dim = int(intermediate_dim) + self.act_fn = act_fn + if config is not None: + self.config = config + + if device is None or (isinstance(device, torch.device) and device.type == "meta"): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + for expert_idx in range(self.num_experts): + expert = nn.ModuleDict( + { + "gate_proj": nn.Linear( + self.hidden_dim, self.intermediate_dim, + bias=False, dtype=dtype, device=device, + ), + "up_proj": nn.Linear( + self.hidden_dim, self.intermediate_dim, + bias=False, dtype=dtype, device=device, + ), + "down_proj": nn.Linear( + self.intermediate_dim, self.hidden_dim, + bias=False, dtype=dtype, device=device, + ), + } + ) + setattr(self, str(expert_idx), expert) + + def __getitem__(self, idx): + return getattr(self, str(idx)) + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot( + top_k_index, num_classes=self.num_experts + ) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + expert_layer = getattr(self, str(int(expert_idx))) + gate = expert_layer["gate_proj"](current_state) + up = expert_layer["up_proj"](current_state) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = 
expert_layer["down_proj"](current_hidden_states) + current_hidden_states = ( + current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + ) + final_hidden_states.index_add_( + 0, token_idx, current_hidden_states.to(final_hidden_states.dtype) + ) + + return final_hidden_states + + +def _is_fused_moe_experts(module) -> bool: + """Heuristic: matches HF Qwen3MoeExperts / HYV3Experts / similar. + + They all expose ``gate_up_proj`` / ``down_proj`` as ``nn.Parameter`` + plus ``num_experts``, ``hidden_dim``, ``intermediate_dim``, ``act_fn``.""" + if isinstance(module, LinearizedMoeExperts): + return False + required = ("gate_up_proj", "down_proj", "num_experts", "hidden_dim", + "intermediate_dim", "act_fn") + return all(hasattr(module, a) for a in required) + + +def _ds_full_shape(param): + """Full shape of a parameter, accounting for ZeRO-3 sharding.""" + shape = getattr(param, "ds_shape", None) + if shape is None: + shape = param.shape + return tuple(int(x) for x in shape) + + +def linearize_moe_experts_empty(model, dtype=torch.bfloat16) -> int: + """Replace every fused MoE experts module in ``model`` with an empty + :class:`LinearizedMoeExperts`. + + Under ZeRO-3 the new ``nn.Linear`` parameters are created inside a + ``deepspeed.zero.Init`` context so they get partitioned immediately. + Weights are NOT copied from the old fused tensors — the streaming + safetensors loader is responsible for that. The old module is dropped + afterwards. + """ + targets = [] + for name, module in model.named_modules(): + if _is_fused_moe_experts(module): + targets.append(name) + + if not targets: + return 0 + + z3 = is_deepspeed_zero3_enabled() + replaced = 0 + for name in targets: + parent, sub_name = find_parent_layer_and_sub_name(model, name) + old = getattr(parent, sub_name) + + # Resolve dimensions from the (possibly sharded) old tensors. 
+ gate_up_shape = _ds_full_shape(old.gate_up_proj) + down_shape = _ds_full_shape(old.down_proj) + num_experts = int(old.num_experts) + hidden_dim = int(old.hidden_dim) + # gate_up_proj: [num_experts, 2*intermediate_dim, hidden_dim] + if len(gate_up_shape) >= 3 and gate_up_shape[-2] % 2 == 0: + intermediate_dim = gate_up_shape[-2] // 2 + elif len(down_shape) >= 3: + intermediate_dim = int(down_shape[-1]) + else: + intermediate_dim = int(old.intermediate_dim) + + old_dtype = old.gate_up_proj.dtype if hasattr(old.gate_up_proj, "dtype") else dtype + config = getattr(old, "config", None) + act_fn = old.act_fn + + ctx = deepspeed.zero.Init() if z3 else nullcontext() + with ctx: + new_module = LinearizedMoeExperts( + num_experts=num_experts, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + act_fn=act_fn, + dtype=old_dtype, + device=None, # let LinearizedMoeExperts pick cuda + config=config, + ) + + setattr(parent, sub_name, new_module) + del old + replaced += 1 + _cleanup() + + print_info(f"[zero3] linearize_moe_experts_empty: replaced {replaced} fused module(s).") + return replaced + + +# --------------------------------------------------------------------------- +# Empty model construction + streaming weight loader +# --------------------------------------------------------------------------- + + +def _resolve_dtype(torch_dtype, config): + if isinstance(torch_dtype, torch.dtype): + return torch_dtype + if isinstance(torch_dtype, str) and torch_dtype != "auto": + return getattr(torch, torch_dtype) + resolved = getattr(config, "torch_dtype", None) or torch.float32 + if isinstance(resolved, str): + return getattr(torch, resolved) + return resolved + + +def zero3_empty_model_from_pretrained( + model_path, + torch_dtype="auto", + trust_remote_code=True, + use_cache=False, + attn_implementation="default", + log_prefix="[zero3]", +): + """Build an EMPTY ZeRO-3 sharded model from a HuggingFace ``model_path``. 
+ + Linearizes all fused MoE experts immediately so subsequent QuantLinear + insertion can iterate flat ``nn.Linear`` modules. Does NOT load + weights — caller must invoke :func:`stream_load_weights`. + """ + from transformers import AutoConfig, AutoModelForCausalLM + from transformers.integrations.deepspeed import HfDeepSpeedConfig # noqa: F401 + + # ``no_init_weights`` / ``no_tie_weights`` moved across transformers + # versions: newest (>=5.x) expose them under ``transformers.initialization``; + # older releases kept them in ``modeling_utils`` or the top-level package. + no_init_weights = None + no_tie_weights = None + for mod in ("transformers.initialization", "transformers.modeling_utils", "transformers"): + try: + m = __import__(mod, fromlist=["no_init_weights"]) + if hasattr(m, "no_init_weights"): + no_init_weights = m.no_init_weights + if hasattr(m, "no_tie_weights"): + no_tie_weights = m.no_tie_weights + except Exception: # noqa: BLE001 + continue + if no_init_weights is not None and no_tie_weights is not None: + break + if no_init_weights is None: + no_init_weights = nullcontext # type: ignore + if no_tie_weights is None: + no_tie_weights = nullcontext # type: ignore + + config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) + if attn_implementation != "default": + config._attn_implementation = attn_implementation + if use_cache is not None: + config.use_cache = use_cache + + resolved = _resolve_dtype(torch_dtype, config) + print_info( + f"{log_prefix} build empty ZeRO-3 model dtype={resolved} from config " + f"{getattr(config, 'model_type', '?')}" + ) + + # ``from_config`` triggers ``deepspeed.zero.Init`` automatically when + # ``HfTrainerDeepSpeedConfig`` is registered (i.e. when + # is_deepspeed_zero3_enabled() returns True). 
+ with no_init_weights(), no_tie_weights(): + model = AutoModelForCausalLM.from_config( + config, torch_dtype=resolved, trust_remote_code=trust_remote_code, + ) + + # Linearize fused MoE experts BEFORE weight loading so the loader can + # write per-expert slices directly into the new Linear targets. + linearize_moe_experts_empty(model, dtype=resolved) + + return model + + +def _shards(model_path): + """Yield ``(shard_path, [keys])`` for every safetensors shard.""" + from safetensors import safe_open + + index_path = os.path.join(model_path, "model.safetensors.index.json") + if os.path.isfile(index_path): + with open(index_path, "r") as f: + weight_map = json.load(f)["weight_map"] + per_shard = {} + for key, shard in weight_map.items(): + per_shard.setdefault(shard, []).append(key) + for shard in sorted(per_shard): + yield os.path.join(model_path, shard), per_shard[shard] + return + + paths = sorted(glob.glob(os.path.join(model_path, "*.safetensors"))) + if not paths: + raise FileNotFoundError(f"No safetensors found under {model_path}") + for shard_path in paths: + with safe_open(shard_path, framework="pt") as r: + yield shard_path, list(r.keys()) + + +def _broadcast_into_target(src, target, *, is_buffer=False, key=None): + """Copy ``src`` (rank0 only, or None on other ranks) into ``target``. + + Handles three cases: + * ZeRO-3 sharded ``Parameter``: gather, rank0 writes, exit gather. + * Regular distributed ``Parameter`` / replicated buffer: rank0 stages, + then broadcast. + * Single-process: direct copy. + """ + dist_active = ( + torch.distributed.is_available() and torch.distributed.is_initialized() + ) + + if is_zero3_param(target): + with gathered_param_if_zero3(target, modifier_rank=0): + if _rank() == 0: + if src is None or src.shape != target.shape: + return False + target.data.copy_(src.to(device=target.device, dtype=target.dtype)) + return True + + # Regular tensor (parameter or buffer). 
+ if dist_active: + if _rank() == 0: + if src is None or (not is_buffer and src.shape != target.shape): + return False + tmp = src.to(device=target.device, dtype=target.dtype).contiguous() + else: + tmp = torch.empty_like(target) + torch.distributed.broadcast(tmp, src=0) + target.data.copy_(tmp) + return True + + if src is None: + return False + target.data.copy_(src.to(device=target.device, dtype=target.dtype)) + return True + + +def stream_load_weights(model, model_path, log_prefix="[zero3]"): + """Stream a HF safetensors checkpoint into ``model``. + + Recognises fused MoE keys ``*.experts.gate_up_proj`` and + ``*.experts.down_proj`` and dispatches the per-expert slices into the + matching :class:`LinearizedMoeExperts` children. All other keys are + matched against ``model.named_parameters()`` / ``named_buffers()``. + + rank0 reads the bytes; ZeRO-3 sharded targets are filled inside + ``GatheredParameters(modifier_rank=0)``; replicated tensors are + broadcast. + """ + from safetensors import safe_open + + name_to_param = dict(model.named_parameters()) + name_to_buffer = dict(model.named_buffers()) + rank = _rank() + + loaded = 0 + skipped = 0 + seen_targets = set() + + for shard_path, keys in _shards(model_path): + with safe_open(shard_path, framework="pt") as reader: + for key in keys: + if key.endswith(".experts.gate_up_proj"): + base = key[: -len(".gate_up_proj")] + src = reader.get_tensor(key) if rank == 0 else None + n_exp = ( + int(src.shape[0]) if src is not None + else _infer_num_experts(base, name_to_param) + ) + for i in range(n_exp): + gkey = f"{base}.{i}.gate_proj.weight" + ukey = f"{base}.{i}.up_proj.weight" + gtgt = name_to_param.get(gkey) + utgt = name_to_param.get(ukey) + if gtgt is None or utgt is None: + skipped += 2 + continue + gsrc = src[i].chunk(2, dim=-2)[0] if src is not None else None + usrc = src[i].chunk(2, dim=-2)[1] if src is not None else None + if _broadcast_into_target(gsrc, gtgt, key=gkey): + seen_targets.add(gkey); loaded += 1 + 
else: + skipped += 1 + if _broadcast_into_target(usrc, utgt, key=ukey): + seen_targets.add(ukey); loaded += 1 + else: + skipped += 1 + del src + elif key.endswith(".experts.down_proj"): + base = key[: -len(".down_proj")] + src = reader.get_tensor(key) if rank == 0 else None + n_exp = ( + int(src.shape[0]) if src is not None + else _infer_num_experts(base, name_to_param) + ) + for i in range(n_exp): + dkey = f"{base}.{i}.down_proj.weight" + dtgt = name_to_param.get(dkey) + if dtgt is None: + skipped += 1 + continue + dsrc = src[i] if src is not None else None + if _broadcast_into_target(dsrc, dtgt, key=dkey): + seen_targets.add(dkey); loaded += 1 + else: + skipped += 1 + del src + else: + tgt = name_to_param.get(key) + is_buf = False + if tgt is None: + tgt = name_to_buffer.get(key) + is_buf = tgt is not None + if tgt is None: + skipped += 1 + continue + src = reader.get_tensor(key) if rank == 0 else None + if _broadcast_into_target(src, tgt, is_buffer=is_buf, key=key): + seen_targets.add(key); loaded += 1 + else: + skipped += 1 + del src + _cleanup() + print_info(f"{log_prefix} loaded shard {os.path.basename(shard_path)}") + + all_targets = set(name_to_param) | set(name_to_buffer) + missing = sorted(all_targets - seen_targets) + print_info( + f"{log_prefix} stream_load_weights done: " + f"loaded={loaded} skipped={skipped} missing={len(missing)}" + ) + if missing: + print_info(f"{log_prefix} first missing keys: {missing[:10]}") + + try: + model.tie_weights() + except Exception as e: # noqa: BLE001 + print_info(f"{log_prefix} tie_weights skipped: {e}") + + +def _infer_num_experts(base, name_to_param): + prefix = f"{base}." 
+ ids = [] + for name in name_to_param: + if not name.startswith(prefix): + continue + first = name[len(prefix):].split(".", 1)[0] + if first.isdigit(): + ids.append(int(first)) + return (max(ids) + 1) if ids else 0 + + +# --------------------------------------------------------------------------- +# Streaming PTQ-scale loader for QAT warm-start +# --------------------------------------------------------------------------- + + +_SCALE_SUFFIX_RULES = [ + # (suffix_in_ckpt, quantizer_attr_on_QuantLinear, sub_attr, layer_name_rewrite) + (".weight_zero_point", "weight_quantizer", "zero_point", None), + (".input_zero_point", "act_quantizer", "zero_point", None), + (".weight_scale", "weight_quantizer", "scale", None), + (".input_scale", "act_quantizer", "scale", None), + (".k_cache.scale", "qkv_quantizer", "scale", ".k_proj"), + (".v_cache.scale", "qkv_quantizer", "scale", ".v_proj"), +] +# Longest first to avoid '.scale' winning over '.k_cache.scale'. +_SCALE_SUFFIX_RULES.sort(key=lambda r: len(r[0]), reverse=True) + + +def _parse_scale_key(key): + for suffix, qname, sub, rewrite in _SCALE_SUFFIX_RULES: + if key.endswith(suffix): + base = key[: -len(suffix)] + return (base + rewrite if rewrite else base), qname, sub + return None + + +def _expand_scale_targets(layer_name, qname, sub, named_modules): + """Expand a checkpoint key that targets a fused MoE expert tensor into the + matching per-expert linears in the linearized model. For non-MoE keys + returns ``[(layer_name, qname, sub)]`` unchanged.""" + if layer_name in named_modules: + return [(layer_name, qname, sub)] + + # PTQ checkpoint may store scales for the fused expert matrix; map them + # to every per-expert Linear we expanded into. 
+ if layer_name.endswith(".experts.gate_up_proj"): + base = layer_name[: -len(".gate_up_proj")] + return [ + (n, qname, sub) for n in named_modules + if n.startswith(base + ".") and (n.endswith(".gate_proj") or n.endswith(".up_proj")) + ] + if layer_name.endswith(".experts.down_proj"): + base = layer_name[: -len(".down_proj")] + return [ + (n, qname, sub) for n in named_modules + if n.startswith(base + ".") and n.endswith(".down_proj") + ] + return [] + + +def _copy_scale_into(src, target): + """rank0-driven copy of a scale-like tensor into a (possibly ZeRO-3) + Parameter, with shape coercion for scalar/per-tensor mismatch.""" + rank = _rank() + ok = True + with gathered_param_if_zero3(target, modifier_rank=0): + if rank == 0: + s = src + if s.numel() == target.numel(): + s = s.reshape(target.shape) + else: + try: + s = s.expand_as(target).contiguous() + except RuntimeError: + ok = False + if ok: + target.data.copy_(s.to(device=target.device, dtype=target.dtype)) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + device = ( + torch.device("cuda", torch.cuda.current_device()) + if torch.cuda.is_available() else target.device + ) + flag = torch.tensor(int(ok), device=device) + torch.distributed.broadcast(flag, src=0) + ok = bool(flag.item()) + return ok + + +def stream_load_scales(model, ckpt_dir, log_prefix="[zero3]"): + """Read a PTQ "real" checkpoint and write its scale / zero_point / + kv-cache scale tensors into the matching ``QuantLinear`` quantizer + parameters of ``model``. + + Sets ``act_quantizer.init = True`` for every static activation + quantizer that successfully receives a scale, so the lazy-init pass is + skipped. + """ + from safetensors import safe_open + + # Resolve nested layout (some PTQ exporters nest under final_quant_checkpoint/). 
+ if not glob.glob(os.path.join(ckpt_dir, "*.safetensors")): + nested = os.path.join(ckpt_dir, "final_quant_checkpoint") + if os.path.isdir(nested): + ckpt_dir = nested + + files = sorted(glob.glob(os.path.join(ckpt_dir, "*.safetensors"))) + if not files: + raise FileNotFoundError(f"No *.safetensors in {ckpt_dir}") + + # Lazy import to avoid circular dependency: this module is imported by + # angelslim.utils, which is imported very early. + from ..compressor.qat.modules.quantizer import QuantLinear + + named_modules = dict(model.named_modules()) + rank = _rank() + loaded = 0 + skipped = 0 + + for src_file in files: + with safe_open(src_file, framework="pt") as reader: + for key in reader.keys(): + parsed = _parse_scale_key(key) + if parsed is None: + continue + layer_name, qname, sub = parsed + targets = _expand_scale_targets(layer_name, qname, sub, named_modules) + if not targets: + skipped += 1 + continue + src = reader.get_tensor(key) if rank == 0 else None + for tgt_layer, tgt_qname, tgt_sub in targets: + module = named_modules.get(tgt_layer) + if not isinstance(module, QuantLinear): + skipped += 1 + continue + quantizer = getattr(module, tgt_qname, None) + if quantizer is None: + skipped += 1 + continue + target = getattr(quantizer, tgt_sub, None) + if not isinstance(target, torch.nn.Parameter): + skipped += 1 + continue + if _copy_scale_into(src, target): + loaded += 1 + if tgt_qname == "act_quantizer" and tgt_sub == "scale": + quantizer.init = True + else: + skipped += 1 + + print_info( + f"{log_prefix} stream_load_scales: loaded={loaded} skipped={skipped} from {ckpt_dir}" + ) + + +# --------------------------------------------------------------------------- +# Saving a sharded model via the model-specific save_func +# --------------------------------------------------------------------------- + + +def consolidated_state_dict(model): + """rank-0 CPU state_dict for a possibly ZeRO-3 sharded ``model``. 
+ + Other ranks see an empty dict (matching the contract of HF/Trainer + save callbacks). Includes persistent buffers.""" + rank = _rank() + sd = {} + for name, param in model.named_parameters(): + with gathered_param_if_zero3(param): + if rank == 0: + sd[name] = param.detach().cpu().clone() + if rank == 0: + for module_name, module in model.named_modules(): + for buf_name, buf in module.named_buffers(recurse=False): + if buf is None or buf_name in module._non_persistent_buffers_set: + continue + full = f"{module_name}.{buf_name}" if module_name else buf_name + sd[full] = buf.detach().cpu().clone() + return sd + + +def save_via_model_save_func(quant_model, save_func, save_target_dir): + """Invoke ``save_func.save(...)`` with the model's ``state_dict`` patched + to return the consolidated rank-0 dict. + + No-op (delegates straight to ``save_func.save``) when no parameters are + sharded. + """ + if not model_has_zero3_params(quant_model.model): + save_func.save(save_target_dir) + return + + rank = _rank() + sd = consolidated_state_dict(quant_model.model) + hf_model = quant_model.get_model() + original = hf_model.state_dict + + def _patched(*args, **kwargs): + return sd if rank == 0 else {} + + try: + hf_model.state_dict = _patched # type: ignore[method-assign] + if rank == 0: + save_func.save(save_target_dir) + finally: + hf_model.state_dict = original # type: ignore[method-assign] + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + +# --------------------------------------------------------------------------- +# DeepSpeed engine: tolerate scale parameters that appear in multiple param +# groups (the scale tensor itself is unique, but our optimizer construction +# may pick it up from multiple QuantLinear children of a shared MoE expert). 
+# --------------------------------------------------------------------------- + + +def patch_deepspeed_duplicate_check(): + """No-op DeepSpeed's ``_check_for_duplicates`` so QAT scale parameters + that share storage across modules don't crash optimizer init. + + Idempotent: only patches once per process. + """ + try: + from deepspeed.runtime.engine import DeepSpeedEngine + except Exception as exc: # noqa: BLE001 + print_info(f"[zero3] skip duplicate-check patch: {exc}") + return + if getattr(DeepSpeedEngine, "_angelslim_skip_dup_check", False): + return + + def _noop(self, basic_optimizer): # noqa: ARG001 + return + + DeepSpeedEngine._check_for_duplicates = _noop + DeepSpeedEngine._angelslim_skip_dup_check = True + print_info("[zero3] patched DeepSpeed _check_for_duplicates (idempotent).") diff --git a/configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json b/configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json new file mode 100644 index 00000000..e513dd7b --- /dev/null +++ b/configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json @@ -0,0 +1,45 @@ +{ + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "total_num_steps": "auto", + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "none" + }, + "offload_param": { + "device": "none" + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 10, + "train_batch_size": "auto", + 
"train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml new file mode 100644 index 00000000..da23b3a5 --- /dev/null +++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml @@ -0,0 +1,81 @@ +global: + save_path: ./output + +model: + name: Qwen + model_path: /apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-30B-A3B + trust_remote_code: true + torch_dtype: auto + # ZeRO-3 wants HF to NOT place the model with device_map; we pass + # ``None`` here and the Engine normalizes the string "None"/"distributed". + device_map: None + low_cpu_mem_usage: true + use_cache: false + +compression: + name: QAT + quantization: + name: fp8_static + bits: 8 + quant_method: + weight: per-tensor + activation: per-tensor + ignore_layers: ["lm_head", "embed_tokens", "gate.weight"] + QAT: + hf_dataset: Salesforce/wikitext,wikitext-2-raw-v1 + # REQUIRED under ZeRO-3: bootstrap scales from a previous PTQ "real" + # checkpoint (same model_path works if it already carries scales; else + # point to an AngelSlim PTQ output dir). For the plain base model we + # just fall back to the init-value below, so this file MUST exist or + # the directory must contain *.safetensors even if they only carry + # base weights (scales will be initialised from weight_scale_init_value + # and activation_scale_init_value). 
+ from_ptq_ckpt: /apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-30B-A3B + training_mode: end2end + dist_mode: hf + save_format: real + do_train: true + loss_type: origin + loss_topk: null + kd_temperature: 1.0 + kd_alpha: 0.5 + plugin_config: + enable_scale: true + quant_config: + use_weight_quant: true + use_activation_quant: true + use_qkv_quant: false + # Init values used under ZeRO-3 when weight data is not accessible + # (and as fallback whenever the from_ptq_ckpt does not carry scale + # tensors). + weight_scale_init_value: 0.1 + activation_scale_init_value: 0.1 + learnable: + act_scale: false + weight_scale: true + kv_scale: false + norm: false + lwc: false + hf_args: + logging_steps: 1 + logging_first_step: true + per_device_train_batch_size: 1 + gradient_accumulation_steps: 1 + learning_rate: 1.0e-4 + lr_scheduler_type: cosine + num_train_epochs: 1 + max_steps: 3 + save_strategy: "no" + warmup_ratio: 0.0 + max_grad_norm: 1.0 + gradient_checkpointing: true + gradient_checkpointing_kwargs: + use_reentrant: true + deepspeed: configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json + +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl + max_seq_length: 4096 + num_samples: 256 + batch_size: 1 diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml new file mode 100644 index 00000000..92eeb025 --- /dev/null +++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml @@ -0,0 +1,69 @@ +global: + save_path: ./output_4b_zero3 + +model: + name: Qwen + model_path: /apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-4B + trust_remote_code: true + torch_dtype: auto + device_map: None + low_cpu_mem_usage: true + use_cache: false + +compression: + name: QAT + quantization: + name: fp8_static + bits: 8 + 
quant_method: + weight: per-tensor + activation: per-tensor + ignore_layers: ["lm_head", "embed_tokens"] + QAT: + hf_dataset: Salesforce/wikitext,wikitext-2-raw-v1 + from_ptq_ckpt: /apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-4B + training_mode: end2end + dist_mode: hf + save_format: real + do_train: true + loss_type: origin + loss_topk: null + kd_temperature: 1.0 + kd_alpha: 0.5 + plugin_config: + enable_scale: true + quant_config: + use_weight_quant: true + use_activation_quant: true + use_qkv_quant: false + weight_scale_init_value: 0.1 + activation_scale_init_value: 0.1 + learnable: + act_scale: false + weight_scale: true + kv_scale: false + norm: false + lwc: false + hf_args: + logging_steps: 1 + logging_first_step: true + per_device_train_batch_size: 1 + gradient_accumulation_steps: 1 + learning_rate: 1.0e-4 + lr_scheduler_type: cosine + num_train_epochs: 1 + max_steps: 5 + save_strategy: "no" + warmup_ratio: 0.0 + max_grad_norm: 1.0 + gradient_checkpointing: true + gradient_checkpointing_kwargs: + use_reentrant: true + deepspeed: configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json + +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl + max_seq_length: 512 + num_samples: 16 + batch_size: 1 diff --git a/scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh b/scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh new file mode 100755 index 00000000..1b244127 --- /dev/null +++ b/scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +set -euo pipefail + +export PYTORCH_ALLOC_CONF="expandable_segments:True" + +NPROC=${NPROC:-8} +CONFIG=${CONFIG:-configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml} + +torchrun --nproc_per_node=${NPROC} \ + tools/run.py \ + -c "${CONFIG}" diff --git a/scripts/qat/run_qat_for_qwen_4b_zero3.sh b/scripts/qat/run_qat_for_qwen_4b_zero3.sh new file mode 100755 index 00000000..45181701 --- /dev/null +++ 
b/scripts/qat/run_qat_for_qwen_4b_zero3.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Smoke-test launcher: Qwen3-4B with DeepSpeed ZeRO-3. Matches the full +# 30B-A3B flow but on a model small enough to iterate quickly. + +export PYTORCH_ALLOC_CONF="expandable_segments:True" + +NPROC=${NPROC:-2} +CONFIG=${CONFIG:-configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml} + +torchrun --nproc_per_node=${NPROC} \ + tools/run.py \ + -c "${CONFIG}" diff --git a/scripts/qat/test_moe_zero3_build.py b/scripts/qat/test_moe_zero3_build.py new file mode 100644 index 00000000..eb18f64f --- /dev/null +++ b/scripts/qat/test_moe_zero3_build.py @@ -0,0 +1,98 @@ +"""Standalone smoke test: under ZeRO-3, build an empty Qwen3-30B-A3B model +with layers trimmed to 2, linearize MoE, stream weights, and verify the +per-expert Linear shapes are correct. + +Run: + torchrun --nproc_per_node=2 scripts/qat/test_moe_zero3_build.py +""" +import os +import sys +import torch +import torch.distributed as dist + +# Initialise torch distributed first (required by deepspeed.zero.Init). +dist.init_process_group(backend="nccl") +torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + +# Register HfTrainerDeepSpeedConfig so is_deepspeed_zero3_enabled() returns True. 
+from transformers import Seq2SeqTrainingArguments # noqa: E402 + +ds_config = { + "zero_optimization": {"stage": 3, "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9}, + "train_batch_size": 2, + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "bf16": {"enabled": True}, +} +_hf_args = Seq2SeqTrainingArguments( + output_dir="./tmp_out", deepspeed=ds_config, bf16=True, + per_device_train_batch_size=1, +) + +from transformers import AutoConfig # noqa: E402 +from angelslim.utils import ( # noqa: E402 + is_deepspeed_zero3_enabled, is_zero3_param, + zero3_empty_model_from_pretrained, stream_load_weights, linearize_moe_experts_empty, +) + +assert is_deepspeed_zero3_enabled(), "HF ZeRO-3 not registered" + +MODEL_PATH = "/apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-30B-A3B" + +# Trim layers to 2 for quick iteration. +cfg = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) +cfg.num_hidden_layers = 2 +cfg.use_cache = False + +# Build empty model from the trimmed config. 
+from transformers import AutoModelForCausalLM # noqa: E402 +from transformers.initialization import no_init_weights, no_tie_weights # noqa: E402 + +with no_init_weights(), no_tie_weights(): + model = AutoModelForCausalLM.from_config(cfg, torch_dtype=torch.bfloat16, trust_remote_code=True) + +if dist.get_rank() == 0: + print("Built empty model:", type(model).__name__) + print("num layers:", len(model.model.layers)) + # Inspect first layer MoE before linearization + layer = model.model.layers[0] + mlp = layer.mlp + print("mlp type:", type(mlp).__name__) + if hasattr(mlp, "experts"): + experts = mlp.experts + print("experts type:", type(experts).__name__) + print("has gate_up_proj:", hasattr(experts, "gate_up_proj")) + if hasattr(experts, "gate_up_proj"): + print("gate_up_proj ds_shape:", getattr(experts.gate_up_proj, "ds_shape", None)) + print("is_zero3:", is_zero3_param(experts.gate_up_proj)) + +replaced = linearize_moe_experts_empty(model, dtype=torch.bfloat16) +dist.barrier() + +if dist.get_rank() == 0: + print(f"Replaced {replaced} fused experts") + layer = model.model.layers[0] + mlp = layer.mlp + experts = mlp.experts + print("After linearize - experts type:", type(experts).__name__) + print("num_experts attr:", experts.num_experts) + # Inspect one expert + e0 = experts[0] + print("expert 0:", type(e0).__name__) + print(" gate_proj weight shape:", e0["gate_proj"].weight.shape, + "is_zero3:", is_zero3_param(e0["gate_proj"].weight)) + print(" gate_proj ds_shape:", getattr(e0["gate_proj"].weight, "ds_shape", None)) + +# Now stream load +stream_load_weights(model, MODEL_PATH, log_prefix=f"[rank{dist.get_rank()}]") +dist.barrier() + +# Note: because we trimmed to 2 layers, the checkpoint has weights for +# layers [0..47], so layers [2..47] will appear as "unused keys" in the +# missing-key summary (not an error). 
+ +if dist.get_rank() == 0: + print("\n=== Build + linearize + stream_load_weights OK ===") + +dist.destroy_process_group() diff --git a/tools/run.py b/tools/run.py index c4135199..be84bf4c 100644 --- a/tools/run.py +++ b/tools/run.py @@ -267,6 +267,34 @@ def weight_only_run(config): ) +def _prewarm_hf_deepspeed_config(config): + """Pre-construct ``Seq2SeqTrainingArguments`` so HF's + ``HfTrainerDeepSpeedConfig`` weak-ref is registered BEFORE + ``from_pretrained`` runs. That is what flips + ``is_deepspeed_zero3_enabled()`` to True and makes our + ``BaseLLMModel.from_pretrained`` take the ZeRO-3 path. + + Returns the constructed TrainingArguments (kept alive via the caller's + local variable) or None if not applicable. + """ + compress_cfg = getattr(config, "compression_config", None) + qat_cfg = getattr(compress_cfg, "QAT", None) if compress_cfg is not None else None + hf_args = getattr(qat_cfg, "hf_args", None) if qat_cfg is not None else None + if not hf_args or not hf_args.get("deepspeed"): + return None + + from transformers import Seq2SeqTrainingArguments + + trainer_args = Seq2SeqTrainingArguments( + output_dir=config.global_config.save_path, + **hf_args, + ) + print_info( + "[DeepSpeed pre-warm] HfTrainerDeepSpeedConfig registered before model load." + ) + return trainer_args + + def run(config): """ Run the LLM compression process based on the provided configuration. @@ -295,6 +323,12 @@ def run(config): weight_only_run(config) return + # QAT + DeepSpeed: register HfTrainerDeepSpeedConfig BEFORE loading the + # model so ``from_pretrained`` takes the ZeRO-3 path. No-op otherwise. + # The returned object must stay alive until after the model is built + # because HF's weak-ref mechanism drops the config otherwise. 
+ _hf_ds_args = _prewarm_hf_deepspeed_config(config) + # Step 2: Execute complete pipeline slim_engine = Engine() @@ -312,6 +346,9 @@ def run(config): attn_implementation=model_config.attn_implementation, deploy_backend=global_config.deploy_backend, ) + # Safe to release now: the model is built and any deepspeed.zero.Init + # effects have already happened on all parameters. + del _hf_ds_args # Step 4: Prepare data (optional custom dataloader) if compress_config.need_dataset: From d57cdd6d1c28c05dd835a0cdd354a08de2c0da9d Mon Sep 17 00:00:00 2001 From: root Date: Sun, 26 Apr 2026 19:59:48 +0800 Subject: [PATCH 02/13] fix bug for zero3 save and dataset labels --- angelslim/compressor/qat/modules/quantizer.py | 3 - angelslim/compressor/qat/qat.py | 162 +++++++++++++++--- angelslim/data/text_dataset.py | 13 +- angelslim/utils/zero3_io.py | 28 ++- ..._fp8_static_end2end_learn_scale_zero3.yaml | 4 +- ..._fp8_static_end2end_learn_scale_zero3.yaml | 4 +- 6 files changed, 175 insertions(+), 39 deletions(-) diff --git a/angelslim/compressor/qat/modules/quantizer.py b/angelslim/compressor/qat/modules/quantizer.py index e0493709..0c45bf15 100644 --- a/angelslim/compressor/qat/modules/quantizer.py +++ b/angelslim/compressor/qat/modules/quantizer.py @@ -626,9 +626,6 @@ def __init__( ) def forward(self, input: torch.Tensor): - if input.shape[0] == 0: - return self.fwd_func(input, self.weight, self.bias) - weight = self.weight_quantizer(self.weight) if self.use_weight_quant else self.weight if self.use_act_quant: input = self.act_quantizer(input) diff --git a/angelslim/compressor/qat/qat.py b/angelslim/compressor/qat/qat.py index 10e7fc4b..d86fea58 100644 --- a/angelslim/compressor/qat/qat.py +++ b/angelslim/compressor/qat/qat.py @@ -44,6 +44,11 @@ def __init__(self, model, slim_config=None): self.quant_model.init_ptq(slim_config) self.quant_info = self.quant_model.quant_config self.plugin_manager = PluginManager() + # When set, ``save`` will use this rank-0-only state_dict instead 
of + # walking the model again. Populated by ``convert`` under ZeRO-3 to + # avoid keeping a full CPU copy of every layer's QDQModule on every + # rank. + self._rank0_state_dict = None self._init_plugins() self._init_trainer() @@ -81,7 +86,25 @@ def run(self, dataloader): @staticmethod def _gather_clone(tensor): - """Detach + CPU-clone a tensor, gathering if it is a ZeRO-3 shard.""" + """Detach + CPU-clone a tensor, gathering if it is a ZeRO-3 shard. + + WARNING: every rank gets a full CPU copy. Only safe for SMALL tensors + (e.g. scale Parameters). For large weights under ZeRO-3 use + ``_rank0_gather_clone`` instead. + """ + if tensor is None: + return None + with gathered_param_if_zero3(tensor): + return tensor.detach().cpu().clone() + + @staticmethod + def _sym_gather_clone(tensor): + """Symmetric gather-and-clone: every rank gets a full CPU copy. + + Collective timing is symmetric across ranks (minimising NCCL + stalls). Caller is responsible for dropping the clone on rank>0 + immediately to avoid keeping 'world_size' copies alive. 
+ """ if tensor is None: return None with gathered_param_if_zero3(tensor): @@ -91,29 +114,92 @@ def convert(self): if self.save_fmt not in ("real", "real_and_kvcache"): return - print_info("Start QAT convert: replacing QuantLinear with QDQModule...") + zero3 = model_has_zero3_params(self.quant_model.model) + rank = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() else 0 + ) quant_algo = self.quant_info.quant_algo + if not zero3: + # ----- single-GPU / non-ZeRO-3 path: original behaviour ----- + print_info("Start QAT convert: replacing QuantLinear with QDQModule...") + for name, module in [ + (n, m) for n, m in self.quant_model.model.named_modules() + if isinstance(m, QuantLinear) + ]: + weight_scale = ( + module.weight_quantizer.scale.data.clone() + if hasattr(module, "weight_quantizer") else None + ) + input_scale = None + if module.use_act_quant and hasattr(module, "act_quantizer"): + aq = module.act_quantizer + if getattr(aq, "scale", None) is not None: + input_scale = aq.scale.data.clone() + qdq_module = QDQModule( + quant_algo=quant_algo, + weight=module.weight, + weight_scale=weight_scale, + bias=module.bias, + group_size=( + module.weight_quantizer.group_size + if hasattr(module, "weight_quantizer") + and hasattr(module.weight_quantizer, "group_size") + else 128 + ), + input_scale=input_scale, + ) + set_op_by_name(self.quant_model.model, name, qdq_module) + return + + # ----- ZeRO-3 path: every rank gathers + clones per layer (fast, + # NCCL-symmetric), but only rank0 keeps the data by feeding it into + # ``_rank0_state_dict``. rank>0 drops the clone immediately so peak + # CPU remains bounded by ~one layer's worth of tensors per rank. + # Model structure is NOT modified — we stream straight into the + # state_dict and let ``save_via_model_save_func`` patch + # ``state_dict()`` for the underlying save_func. + print_info( + f"[rank{rank}] Start QAT convert (ZeRO-3 mode: stream rank0 " + "state_dict, keep model structure intact)..." 
+ ) + self._rank0_state_dict = {} if rank == 0 else {} + quant_linear_modules = [ - (name, module) - for name, module in self.quant_model.model.named_modules() - if isinstance(module, QuantLinear) + (n, m) for n, m in self.quant_model.model.named_modules() + if isinstance(m, QuantLinear) ] + consumed_prefixes = set() for name, module in quant_linear_modules: - weight = self._gather_clone(module.weight) - bias = self._gather_clone(getattr(module, "bias", None)) + # Symmetric gather: all ranks clone (memcpy, fast) so NCCL + # timing stays tight. This is ~world_size× transient CPU RAM + # for JUST this one layer; we free right after the rank0 + # branch completes. + weight = self._sym_gather_clone(module.weight) + bias = self._sym_gather_clone(getattr(module, "bias", None)) weight_scale = None if hasattr(module, "weight_quantizer"): - weight_scale = self._gather_clone(module.weight_quantizer.scale) - + weight_scale = self._sym_gather_clone(module.weight_quantizer.scale) input_scale = None if module.use_act_quant and hasattr(module, "act_quantizer"): - act_quantizer = module.act_quantizer - if hasattr(act_quantizer, "scale") and act_quantizer.scale is not None: - input_scale = self._gather_clone(act_quantizer.scale) + aq = module.act_quantizer + if getattr(aq, "scale", None) is not None: + input_scale = self._sym_gather_clone(aq.scale) + consumed_prefixes.add(name) + + if rank != 0: + # Drop the clone immediately; next iteration will overwrite + # these locals anyway but be explicit for clarity. + del weight, bias, weight_scale, input_scale + continue + + # rank0 only: run the fp8/int quantize path via a throwaway + # QDQModule, then move its params into the consolidated dict + # and discard the module. 
qdq_module = QDQModule( quant_algo=quant_algo, weight=weight, @@ -127,10 +213,45 @@ def convert(self): ), input_scale=input_scale, ) - set_op_by_name(self.quant_model.model, name, qdq_module) + for sub_name, p in qdq_module.named_parameters(recurse=False): + self._rank0_state_dict[f"{name}.{sub_name}"] = p.detach().cpu() + for sub_name, b in qdq_module.named_buffers(recurse=False): + if sub_name in qdq_module._non_persistent_buffers_set: + continue + self._rank0_state_dict[f"{name}.{sub_name}"] = b.detach().cpu() + del qdq_module, weight, bias, weight_scale, input_scale + + # Second pass: params/buffers that are NOT inside a QuantLinear + # (embeddings, lm_head, layernorms, MoE router gate, ...). The + # collective order MUST be identical across ranks, so this loop + # runs on every rank; only rank0 keeps the data. + for pname, param in self.quant_model.model.named_parameters(): + if any(pname.startswith(p + ".") for p in consumed_prefixes): + continue + with gathered_param_if_zero3(param): + if rank == 0: + self._rank0_state_dict[pname] = param.detach().cpu().clone() + + if rank == 0: + for module_name, mod in self.quant_model.model.named_modules(): + for buf_name, buf in mod.named_buffers(recurse=False): + if buf is None or buf_name in mod._non_persistent_buffers_set: + continue + full_key = f"{module_name}.{buf_name}" if module_name else buf_name + if full_key in self._rank0_state_dict: + continue + self._rank0_state_dict[full_key] = buf.detach().cpu().clone() + print_info( + f"[zero3] convert done: rank0 state_dict has " + f"{len(self._rank0_state_dict)} tensors." 
+ ) def _save_kv_cache_scales(self, save_path: str): """Extract and save KV cache scales to a safetensors file.""" + rank = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() else 0 + ) kv_scales = {} for name, module in self.quant_model.model.named_modules(): if not isinstance(module, QuantLinear): @@ -146,18 +267,11 @@ def _save_kv_cache_scales(self, save_path: str): continue scale_key = f"{cache_name}.scale" scale_tensor = self._gather_clone(module.qkv_quantizer.scale) - if scale_tensor is not None: + if scale_tensor is not None and rank == 0: kv_scales[scale_key] = scale_tensor.float() - if model_has_zero3_params(self.quant_model.model): - # Only rank0 writes the file. - rank = ( - torch.distributed.get_rank() - if torch.distributed.is_initialized() else 0 - ) - if rank != 0: - return - + if rank != 0: + return os.makedirs(save_path, exist_ok=True) out_file = os.path.join(save_path, "kv_cache_scales.safetensors") save_file(kv_scales, out_file) @@ -180,6 +294,7 @@ def save(self, save_path: str): self.quant_model, save_func, os.path.join(save_path, "final_quant_checkpoint"), + prebuilt_state_dict=self._rank0_state_dict, ) # "save_kvcache_only": only export KV cache scales (kv_cache_scales.safetensors) @@ -193,6 +308,7 @@ def save(self, save_path: str): self.quant_model, save_func, os.path.join(save_path, "final_quant_checkpoint"), + prebuilt_state_dict=self._rank0_state_dict, ) self._save_kv_cache_scales(os.path.join(save_path, "final_quant_checkpoint")) diff --git a/angelslim/data/text_dataset.py b/angelslim/data/text_dataset.py index 510b3a05..99afd64b 100644 --- a/angelslim/data/text_dataset.py +++ b/angelslim/data/text_dataset.py @@ -74,9 +74,8 @@ def _load_hf_dataset(self, data_path: str, num_samples: int, block_size: int = 2 inputs = { "input_ids": torch.tensor(result["input_ids"][i]).unsqueeze(0).to(self.device) } - labels = inputs["input_ids"].roll(shifts=-1, dims=-1) - labels[:, -1] = -100 - inputs["labels"] = labels.to(self.device) 
+ # HF CausalLM models shift labels internally; feed labels == input_ids. + inputs["labels"] = inputs["input_ids"].clone() inputs["attention_mask"] = torch.tensor(result["attention_mask"][i]).to(self.device) self.data.append(inputs) @@ -101,8 +100,8 @@ def _load_parquet_data(self, data_path: str, num_samples: int): if "labels" in df.columns: labels = torch.tensor(df["labels"].iloc[i]).unsqueeze(0) else: - labels = model_inputs["input_ids"].roll(shifts=-1, dims=-1) - labels[:, -1] = -100 + # HF CausalLM models shift labels internally; feed labels == input_ids. + labels = model_inputs["input_ids"].clone() data_item = { "input_ids": model_inputs["input_ids"].to(self.device), @@ -157,8 +156,8 @@ def _load_jsonl_data(self, data_path: str, num_samples: int): padding="max_length", ) - labels = model_inputs["input_ids"].roll(shifts=-1, dims=-1) - labels[:, -1] = -100 + # HF CausalLM models shift labels internally; feed labels == input_ids. + labels = model_inputs["input_ids"].clone() self.data.append( { diff --git a/angelslim/utils/zero3_io.py b/angelslim/utils/zero3_io.py index 37e5b9b6..8872f6b5 100644 --- a/angelslim/utils/zero3_io.py +++ b/angelslim/utils/zero3_io.py @@ -738,19 +738,39 @@ def consolidated_state_dict(model): return sd -def save_via_model_save_func(quant_model, save_func, save_target_dir): +def save_via_model_save_func( + quant_model, save_func, save_target_dir, prebuilt_state_dict=None, +): """Invoke ``save_func.save(...)`` with the model's ``state_dict`` patched - to return the consolidated rank-0 dict. + to return a consolidated rank-0 dict. + + Parameters + ---------- + prebuilt_state_dict : dict | None + If provided (rank 0), use this dict directly and skip the per-param + gather+clone pass over ``model``. This is the **recommended** path + under ZeRO-3 because the caller (typically ``QAT.convert``) has + already produced rank0's full state_dict layer-by-layer with bounded + peak memory. Other ranks may pass ``None``. 
+ + If ``None``, fall back to ``consolidated_state_dict(model)``, which + gathers every parameter once more — avoid this when ``model`` + already holds large materialised tensors on rank 0 (it would double + the peak). No-op (delegates straight to ``save_func.save``) when no parameters are sharded. """ - if not model_has_zero3_params(quant_model.model): + if not model_has_zero3_params(quant_model.model) and prebuilt_state_dict is None: save_func.save(save_target_dir) return rank = _rank() - sd = consolidated_state_dict(quant_model.model) + if prebuilt_state_dict is not None: + sd = prebuilt_state_dict if rank == 0 else {} + else: + sd = consolidated_state_dict(quant_model.model) + hf_model = quant_model.get_model() original = hf_model.state_dict diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml index da23b3a5..8a0d591a 100644 --- a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml +++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml @@ -22,7 +22,9 @@ compression: activation: per-tensor ignore_layers: ["lm_head", "embed_tokens", "gate.weight"] QAT: - hf_dataset: Salesforce/wikitext,wikitext-2-raw-v1 + # Leave ``hf_dataset`` unset → End2EndTrainer.prepare_dataset falls back + # to the ``dataset:`` section below (local TextDataset). + hf_dataset: null # REQUIRED under ZeRO-3: bootstrap scales from a previous PTQ "real" # checkpoint (same model_path works if it already carries scales; else # point to an AngelSlim PTQ output dir). 
For the plain base model we diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml index 92eeb025..c8b89d2a 100644 --- a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml +++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml @@ -20,7 +20,9 @@ compression: activation: per-tensor ignore_layers: ["lm_head", "embed_tokens"] QAT: - hf_dataset: Salesforce/wikitext,wikitext-2-raw-v1 + # Leave ``hf_dataset`` unset → End2EndTrainer.prepare_dataset falls back + # to the ``dataset:`` section below (local TextDataset). + hf_dataset: null from_ptq_ckpt: /apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-4B training_mode: end2end dist_mode: hf From 329f0a7080c63fced573099133bce1c7cd18d76c Mon Sep 17 00:00:00 2001 From: root Date: Sun, 26 Apr 2026 22:41:18 +0800 Subject: [PATCH 03/13] feat(qat): add DeepSpeed ZeRO-3 support with MoE linearisation --- angelslim/compressor/qat/modules/quantizer.py | 30 +- .../compressor/qat/plugins/learnable_scale.py | 4 +- angelslim/compressor/qat/qat.py | 19 +- .../qat/trainers/end2end_trainer.py | 241 +++++++++--- angelslim/data/text_dataset.py | 84 +++-- angelslim/utils/__init__.py | 2 +- angelslim/utils/config_parser.py | 7 + angelslim/utils/zero3_io.py | 87 +++-- ..._fp8_static_end2end_learn_scale_zero3.yaml | 39 +- ..._fp8_static_end2end_learn_scale_zero3.yaml | 13 +- docs/source/features/quantization/index.md | 1 + .../source/features/quantization/qat_zero3.md | 343 ++++++++++++++++++ scripts/qat/test_moe_zero3_build.py | 98 ----- tools/run.py | 4 +- 14 files changed, 717 insertions(+), 255 deletions(-) create mode 100644 docs/source/features/quantization/qat_zero3.md delete mode 100644 scripts/qat/test_moe_zero3_build.py diff --git a/angelslim/compressor/qat/modules/quantizer.py 
b/angelslim/compressor/qat/modules/quantizer.py index 0c45bf15..f7754609 100644 --- a/angelslim/compressor/qat/modules/quantizer.py +++ b/angelslim/compressor/qat/modules/quantizer.py @@ -245,9 +245,7 @@ def _init_lwc_params(self, x, config): else: dim1 = 1 - init = ( - torch.ones((dim1, 1), device=device, dtype=torch.float32) * self.lwc_init_value - ) + init = torch.ones((dim1, 1), device=device, dtype=torch.float32) * self.lwc_init_value self.clip_factor_w_max = nn.Parameter(init.clone(), requires_grad=True) self.clip_factor_w_min = nn.Parameter(init.clone(), requires_grad=True) self.sigmoid = nn.Sigmoid() @@ -561,7 +559,14 @@ def fake_quant(self, x): None if self.is_sym else clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax) ) scale, round_zero_point = self._expand_scale_zp(scale, round_zero_point, x) - return self._fake_quant_with_params(x, scale, round_zero_point) + out = self._fake_quant_with_params(x, scale, round_zero_point) + # Scale is kept in fp32 for numerical stability, but multiplying by + # a bf16/fp16 activation upcasts the result. Cast back to the input + # dtype so downstream F.linear / DeepSpeed autocast wrappers see a + # consistent dtype. 
+ if out.dtype != x.dtype: + out = out.to(x.dtype) + return out def forward(self, x: torch.Tensor): if self.bits >= 16: @@ -611,7 +616,10 @@ def __init__( weight_shape = (org_module.out_features, org_module.in_features) if self.use_weight_quant: self.weight_quantizer = Quantizer( - config, quant_info, x=org_module.weight, weight_shape=weight_shape, + config, + quant_info, + x=org_module.weight, + weight_shape=weight_shape, ) if self.use_act_quant: self.act_quantizer = Quantizer(config, quant_info, is_act=True, resume=resume) @@ -629,7 +637,17 @@ def forward(self, input: torch.Tensor): weight = self.weight_quantizer(self.weight) if self.use_weight_quant else self.weight if self.use_act_quant: input = self.act_quantizer(input) - output = self.fwd_func(input, weight, self.bias) + # Defensive dtype alignment: upstream (DeepSpeed ZeRO-3 / HF + # autocast) may have cast ``input`` to fp16 even though we run in + # bf16. Align to the (fake-quantised) weight dtype so F.linear + # stays consistent. + if input.dtype != weight.dtype: + input = input.to(weight.dtype) + # Disable autocast around the matmul so the DeepSpeed + # zero3_linear_wrap wrapper (decorated with ``autocast_custom_fwd``) + # does NOT silently re-cast ``input`` back to the autocast dtype. + with torch.amp.autocast(device_type="cuda", enabled=False): + output = self.fwd_func(input, weight, self.bias) if self.use_qkv_quant: output = self.qkv_quantizer(output) return output diff --git a/angelslim/compressor/qat/plugins/learnable_scale.py b/angelslim/compressor/qat/plugins/learnable_scale.py index 99521691..0f9ef3cc 100644 --- a/angelslim/compressor/qat/plugins/learnable_scale.py +++ b/angelslim/compressor/qat/plugins/learnable_scale.py @@ -79,9 +79,7 @@ def before_train(self, **kwargs): # plan to fill it from a checkpoint (full resume OR PTQ warm-start # OR ZeRO-3 — where lazy_init is impossible). 
act_preallocate = ( - self.resume_ckpt_dir is not None - or self.from_ptq_ckpt_dir is not None - or zero3 + self.resume_ckpt_dir is not None or self.from_ptq_ckpt_dir is not None or zero3 ) for name, module in self.quant_model.model.named_modules(): if isinstance(module, torch.nn.Linear): diff --git a/angelslim/compressor/qat/qat.py b/angelslim/compressor/qat/qat.py index d86fea58..c5c524cd 100644 --- a/angelslim/compressor/qat/qat.py +++ b/angelslim/compressor/qat/qat.py @@ -115,22 +115,21 @@ def convert(self): return zero3 = model_has_zero3_params(self.quant_model.model) - rank = ( - torch.distributed.get_rank() - if torch.distributed.is_initialized() else 0 - ) + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 quant_algo = self.quant_info.quant_algo if not zero3: # ----- single-GPU / non-ZeRO-3 path: original behaviour ----- print_info("Start QAT convert: replacing QuantLinear with QDQModule...") for name, module in [ - (n, m) for n, m in self.quant_model.model.named_modules() + (n, m) + for n, m in self.quant_model.model.named_modules() if isinstance(m, QuantLinear) ]: weight_scale = ( module.weight_quantizer.scale.data.clone() - if hasattr(module, "weight_quantizer") else None + if hasattr(module, "weight_quantizer") + else None ) input_scale = None if module.use_act_quant and hasattr(module, "act_quantizer"): @@ -167,8 +166,7 @@ def convert(self): self._rank0_state_dict = {} if rank == 0 else {} quant_linear_modules = [ - (n, m) for n, m in self.quant_model.model.named_modules() - if isinstance(m, QuantLinear) + (n, m) for n, m in self.quant_model.model.named_modules() if isinstance(m, QuantLinear) ] consumed_prefixes = set() @@ -248,10 +246,7 @@ def convert(self): def _save_kv_cache_scales(self, save_path: str): """Extract and save KV cache scales to a safetensors file.""" - rank = ( - torch.distributed.get_rank() - if torch.distributed.is_initialized() else 0 - ) + rank = torch.distributed.get_rank() if 
torch.distributed.is_initialized() else 0 kv_scales = {} for name, module in self.quant_model.model.named_modules(): if not isinstance(module, QuantLinear): diff --git a/angelslim/compressor/qat/trainers/end2end_trainer.py b/angelslim/compressor/qat/trainers/end2end_trainer.py index 40387b87..79828981 100644 --- a/angelslim/compressor/qat/trainers/end2end_trainer.py +++ b/angelslim/compressor/qat/trainers/end2end_trainer.py @@ -48,74 +48,180 @@ def __init__(self, *args, loss_config=None, quant_config=None, **kwargs): self.loss_type = str(loss_config.get("loss_type", "origin")).lower() self.loss_topk = loss_config.get("loss_topk") self.kd_temperature = float(loss_config.get("kd_temperature", 1.0)) + # ``kd_alpha`` kept for backward compat but IGNORED when + # ``lm_loss_weight`` / ``kd_loss_weight`` are the (new) source of + # truth. self.kd_alpha = float(loss_config.get("kd_alpha", 0.5)) + self.lm_loss_weight = float(loss_config.get("lm_loss_weight", 1.0)) + self.kd_loss_weight = float(loss_config.get("kd_loss_weight", 0.0)) self.use_weight_quant = quant_config.get("use_weight_quant", False) self.use_activation_quant = quant_config.get("use_activation_quant", False) self.use_qkv_quant = quant_config.get("use_qkv_quant", False) - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - if self.loss_type == "origin": - return super().compute_loss( - model, - inputs, - return_outputs=return_outputs, - num_items_in_batch=num_items_in_batch, - ) + # Running metric aggregator keyed by logger mode. 
+ from collections import defaultdict - teacher_logits = self.get_ori_outputs(model, inputs).logits - student_inputs = dict(inputs) - if self.loss_type != "kd": - student_inputs.pop("labels", None) - outputs = model(**student_inputs) - student_logits = outputs.logits + self._qat_metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + + # ------------------------------------------------------------------ + # KD loss helpers + # ------------------------------------------------------------------ + + @staticmethod + def _kl_per_token(log_p_src: torch.Tensor, p_tgt: torch.Tensor) -> torch.Tensor: + """Per-token KL(tgt || src) (shape [N]).""" + return F.kl_div(log_p_src, p_tgt, reduction="none").sum(dim=-1) + + def _compute_kd_components(self, student_logits, teacher_logits, labels): + """Return a dict of per-token KD losses computed only on valid + (label != -100) positions. Keys present depend on ``self.loss_type``. + + Always returns ``forward_kl`` and ``backward_kl`` (useful for + logging even when the main kd loss is e.g. ``mse`` or a topk + variant) — but only when kd_loss_weight > 0. + """ + flat_mask = (labels != -100).reshape(-1) + if flat_mask.sum() == 0: + zero = student_logits.new_zeros(()) + return {"loss": zero, "forward_kl": zero, "backward_kl": zero} + # Flatten to [N, V] and keep only valid tokens. + s_flat = student_logits.flatten(0, -2)[flat_mask] + t_flat = teacher_logits.flatten(0, -2)[flat_mask] + valid_labels = labels.reshape(-1)[flat_mask] + + # Diagnostic KL (always computed at T=1 on valid tokens). + s_logp = F.log_softmax(s_flat, dim=-1) + t_logp = F.log_softmax(t_flat, dim=-1) + s_p = s_logp.exp() + t_p = t_logp.exp() + forward_kl = self._kl_per_token(s_logp, t_p).mean() + backward_kl = self._kl_per_token(t_logp, s_p).mean() + + # Main KD loss according to loss_type. 
if self.loss_type == "kl": - loss = F.kl_div( - F.log_softmax(student_logits.flatten(0, -2), dim=-1), - F.softmax(teacher_logits.flatten(0, -2), dim=-1), - reduction="batchmean", - ) + kd = forward_kl elif self.loss_type == "rkl": - loss = F.kl_div( - F.log_softmax(teacher_logits.flatten(0, -2), dim=-1), - F.softmax(student_logits.flatten(0, -2), dim=-1), - reduction="batchmean", - ) + kd = backward_kl elif self.loss_type == "mse": - loss = F.mse_loss(student_logits, teacher_logits) + kd = F.mse_loss(s_flat, t_flat) elif self.loss_type == "kd": - if getattr(outputs, "loss", None) is None: - raise ValueError("loss_type='kd' requires labels to compute CE loss.") - temperature = max(self.kd_temperature, 1e-6) - alpha = self.kd_alpha - distill_loss = F.kl_div( - F.log_softmax(student_logits.flatten(0, -2) / temperature, dim=-1), - F.softmax(teacher_logits.flatten(0, -2) / temperature, dim=-1), - reduction="batchmean", - ) - loss = outputs.loss * (1 - alpha) + distill_loss * (alpha * temperature * temperature) + # Legacy "kd": temperature-scaled forward KL. Combined loss is + # (alpha*T^2)*KD + (1-alpha)*lm — but we now rely on + # lm/kd_loss_weight for the outer combination, so return just + # the scaled KL here. + T = max(self.kd_temperature, 1e-6) + kd = self._kl_per_token( + F.log_softmax(s_flat / T, dim=-1), + F.softmax(t_flat / T, dim=-1), + ).mean() * (T * T) + elif self.loss_type == "cakld": + # Contextual Asymmetric KL-divergence: per-token mixing of + # forward / reverse KL by teacher's confidence on the label. 
+ per_tok_fkl = self._kl_per_token(s_logp, t_p) + per_tok_bkl = self._kl_per_token(t_logp, s_p) + conf = torch.gather(t_p, dim=-1, index=valid_labels.unsqueeze(-1)).squeeze(-1) # [N] + kd = (conf * per_tok_bkl + (1.0 - conf) * per_tok_fkl).mean() elif self._is_reverse_topk_loss(self.loss_type): - topk = self._resolve_topk(self.loss_type, student_logits.size(-1)) - top_student_logits, indices = student_logits.topk(topk, dim=-1, sorted=False) - top_teacher_logits = teacher_logits.gather(-1, indices) - loss = F.kl_div( - F.log_softmax(top_teacher_logits.flatten(0, -2), dim=-1), - F.softmax(top_student_logits.flatten(0, -2), dim=-1), - reduction="batchmean", - ) + topk = self._resolve_topk(self.loss_type, s_flat.size(-1)) + top_s, idx = s_flat.topk(topk, dim=-1, sorted=False) + top_t = t_flat.gather(-1, idx) + kd = self._kl_per_token( + F.log_softmax(top_t, dim=-1), + F.softmax(top_s, dim=-1), + ).mean() elif self._is_forward_topk_loss(self.loss_type): - topk = self._resolve_topk(self.loss_type, teacher_logits.size(-1)) - top_teacher_logits, indices = teacher_logits.topk(topk, dim=-1, sorted=False) - top_student_logits = student_logits.gather(-1, indices) - loss = F.kl_div( - F.log_softmax(top_student_logits.flatten(0, -2), dim=-1), - F.softmax(top_teacher_logits.flatten(0, -2), dim=-1), - reduction="batchmean", + topk = self._resolve_topk(self.loss_type, t_flat.size(-1)) + top_t, idx = t_flat.topk(topk, dim=-1, sorted=False) + top_s = s_flat.gather(-1, idx) + kd = self._kl_per_token( + F.log_softmax(top_s, dim=-1), + F.softmax(top_t, dim=-1), + ).mean() + else: + raise ValueError( + f"Unsupported QAT kd loss_type: {self.loss_type}. " + "Valid: kl, rkl, mse, kd, cakld, kl_top[_K], r_kl_top[_K]." 
) + + return {"loss": kd, "forward_kl": forward_kl, "backward_kl": backward_kl} + + def _record(self, name, value): + if value is None: + return + mode = "train" if self.model.training else "eval" + v = value.detach().float() if isinstance(value, torch.Tensor) else float(value) + if isinstance(v, torch.Tensor): + self._qat_metrics[mode][name].append(v.item()) else: - raise ValueError(f"Unsupported QAT loss_type: {self.loss_type}") + self._qat_metrics[mode][name].append(v) - return (loss, outputs) if return_outputs else loss + # ------------------------------------------------------------------ + # compute_loss + # ------------------------------------------------------------------ + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + labels = inputs.get("labels", None) + lm_on = self.lm_loss_weight > 0.0 + kd_on = self.kd_loss_weight > 0.0 + + # Back-compat: ``loss_type="origin"`` means "pure HF CE loss" → the + # classic SFT path with no distillation. Honour it even when the + # user forgot to set kd_loss_weight=0. + if self.loss_type == "origin": + kd_on = False + + if not lm_on and not kd_on: + raise ValueError("Both lm_loss_weight and kd_loss_weight are 0 — nothing to optimise.") + + # Student forward — always needed. + # HF CausalLM loss is computed when ``labels`` is present in inputs. + student_inputs = dict(inputs) + if not lm_on: + # Still need labels for flat_mask; pop from student kwargs to + # skip HF's internal CE and save some compute. + student_inputs.pop("labels", None) + outputs = model(**student_inputs) + + lm_loss = outputs.loss if lm_on and getattr(outputs, "loss", None) is not None else None + if lm_on and lm_loss is None: + raise ValueError( + "lm_loss_weight > 0 but model did not return a loss — " + "check that ``labels`` is set in the batch." 
+ ) + + kd_info = None + if kd_on: + if labels is None: + raise ValueError("kd_loss_weight > 0 requires ``labels`` in the batch.") + teacher_logits = self.get_ori_outputs(model, inputs).logits + kd_info = self._compute_kd_components( + outputs.logits, + teacher_logits, + labels, + ) + + # Combine. + total = outputs.logits.new_zeros(()) + if lm_loss is not None: + total = total + self.lm_loss_weight * lm_loss + if kd_info is not None: + total = total + self.kd_loss_weight * kd_info["loss"] + + # Logging: record every component whose weight is > 0, plus the + # always-informative forward/backward KL diagnostics when kd is on. + if lm_on and lm_loss is not None: + self._record("lm_loss", lm_loss) + if kd_on and kd_info is not None: + self._record(f"kd/{self.loss_type}", kd_info["loss"]) + # Diagnostic KL(L/R) for any kd variant. Useful to monitor + # teacher-student disagreement independent of the combined + # objective. + self._record("kd/forward_kl", kd_info["forward_kl"]) + self._record("kd/backward_kl", kd_info["backward_kl"]) + self._record("total_loss", total) + + return (total, outputs) if return_outputs else total @staticmethod def _is_forward_topk_loss(loss_type): @@ -153,6 +259,31 @@ def get_ori_outputs(self, model, inputs): ) return outputs + def log(self, logs, start_time=None, *args, **kwargs): + """Inject running QAT loss components (lm_loss / kd/... / kd/forward_kl + / kd/backward_kl / total_loss) into HuggingFace Trainer's log dict. + + Each value is averaged across the steps accumulated since the + previous ``log`` call, then the accumulator is cleared — matching + HF Trainer's behaviour for its built-in ``loss`` key. + """ + mode = "train" if self.model.training else "eval" + bucket = self._qat_metrics.get(mode, {}) + if bucket: + for key, vals in bucket.items(): + if not vals: + continue + avg = sum(vals) / len(vals) + out_key = key if mode == "train" else f"eval_{key}" + logs[out_key] = float(avg) + bucket.clear() + + # Forward to HF Trainer's log. 
Signature differs across versions. + try: + return super().log(logs, start_time, *args, **kwargs) + except TypeError: + return super().log(logs) + @TrainerFactory.register("end2end") class End2EndTrainer: @@ -174,6 +305,8 @@ def __init__(self, quant_model, config, plugin_manager): "loss_topk": config["compress_config"].QAT.loss_topk, "kd_temperature": config["compress_config"].QAT.kd_temperature, "kd_alpha": config["compress_config"].QAT.kd_alpha, + "lm_loss_weight": config["compress_config"].QAT.lm_loss_weight, + "kd_loss_weight": config["compress_config"].QAT.kd_loss_weight, } self.quant_config = { "use_weight_quant": config["compress_config"] @@ -222,9 +355,7 @@ def _init_optimizer(self): and ("clip_factor_w_max" in n or "clip_factor_w_min" in n), ) lwc_param_count = len(lwc_params) - params.append( - {"params": lwc_params, "weight_decay": wd, "lr": lwc_lr} - ) + params.append({"params": lwc_params, "weight_decay": wd, "lr": lwc_lr}) if not any(group["params"] for group in params): raise ValueError("QAT optimizer has no trainable parameters.") diff --git a/angelslim/data/text_dataset.py b/angelslim/data/text_dataset.py index 99afd64b..67da87fd 100644 --- a/angelslim/data/text_dataset.py +++ b/angelslim/data/text_dataset.py @@ -127,42 +127,82 @@ def _load_jsonl_data(self, data_path: str, num_samples: int): # Prepare messages messages = self._prepare_messages(data) - # Apply chat template - text = self.processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + # Find the LAST assistant turn — loss is computed ONLY on + # this reply. Everything before it (system + user(s) + + # earlier assistant(s)) serves as prompt context. + last_assistant_idx = None + for idx, item in enumerate(messages): + if item["role"] == "assistant": + last_assistant_idx = idx + if last_assistant_idx is None: + # No assistant turn -> nothing to supervise; skip. 
+ continue + prompt_messages = messages[:last_assistant_idx] + assistant_msg = messages[last_assistant_idx] + + # Tokenize the prompt (up to the generation marker) and the + # full conversation separately so we know exactly where the + # assistant reply starts. + prompt_text = self.processor.apply_chat_template( + prompt_messages, tokenize=False, add_generation_prompt=True + ) + full_messages = prompt_messages + [assistant_msg] + full_text = self.processor.apply_chat_template( + full_messages, tokenize=False, add_generation_prompt=False ) - thinking_data = False - for dic in messages: - if dic["role"] == "assistant": - if "" and "" in dic["content"]: - thinking_data = True - break + # Legacy branch: thinking-style data without a chat template. + thinking_data = any( + m["role"] == "assistant" + and "" in m.get("content", "") + and "" in m.get("content", "") + for m in messages + ) if thinking_data: - text = self.processor.bos_token if self.processor.bos_token is not None else "" - for dic in messages: - if dic["role"] == "system": - text += dic["content"] - elif dic["role"] == "user": - text = text + "<|User|>" + dic["content"] + "<|Assistant|>" - elif dic["role"] == "assistant": - text = text + dic["content"] + self.processor.eos_token + bos = self.processor.bos_token or "" + prompt_text = bos + for m in prompt_messages: + if m["role"] == "system": + prompt_text += m["content"] + elif m["role"] == "user": + prompt_text += "<|User|>" + m["content"] + "<|Assistant|>" + elif m["role"] == "assistant": + prompt_text += m["content"] + self.processor.eos_token + full_text = prompt_text + assistant_msg["content"] + self.processor.eos_token + + # Token-level prompt length: count tokens in ``prompt_text`` + # without special-token insertion so it aligns with the + # prefix of the tokenization of ``full_text``. 
+ prompt_ids = self.processor( + prompt_text, + add_special_tokens=False, + return_tensors=None, + )["input_ids"] + prompt_len = len(prompt_ids) model_inputs = self.processor( - text=[text], + text=[full_text], return_tensors="pt", max_length=self.max_length, truncation=True, padding="max_length", ) - # HF CausalLM models shift labels internally; feed labels == input_ids. - labels = model_inputs["input_ids"].clone() + # Build labels: HF CausalLM shifts labels internally, so + # the label at position ``t`` supervises the prediction of + # ``input_ids[t+1]``. Positions before (and at) the end of + # the prompt are set to -100 so they contribute no loss. + input_ids = model_inputs["input_ids"] + attention_mask = model_inputs["attention_mask"] + labels = input_ids.clone() + labels[:, :prompt_len] = -100 + # Also mask padding tokens. + labels[attention_mask == 0] = -100 self.data.append( { - "input_ids": model_inputs["input_ids"].to(self.device), - "attention_mask": model_inputs["attention_mask"].to(self.device), + "input_ids": input_ids.to(self.device), + "attention_mask": attention_mask.to(self.device), "labels": labels.to(self.device), } ) diff --git a/angelslim/utils/__init__.py b/angelslim/utils/__init__.py index 46e1c0d9..78d2771e 100644 --- a/angelslim/utils/__init__.py +++ b/angelslim/utils/__init__.py @@ -30,13 +30,13 @@ from .utils import print_with_rank # noqa: F401 from .utils import rank0_print # noqa: F401 from .utils import set_op_by_name # noqa: F401 +from .zero3_io import LinearizedMoeExperts # noqa: F401 from .zero3_io import consolidated_state_dict # noqa: F401 from .zero3_io import gathered_param_if_zero3 # noqa: F401 from .zero3_io import gathered_params_if_zero3 # noqa: F401 from .zero3_io import is_deepspeed_zero3_enabled # noqa: F401 from .zero3_io import is_zero3_param # noqa: F401 from .zero3_io import linearize_moe_experts_empty # noqa: F401 -from .zero3_io import LinearizedMoeExperts # noqa: F401 from .zero3_io import model_has_zero3_params # 
noqa: F401 from .zero3_io import patch_deepspeed_duplicate_check # noqa: F401 from .zero3_io import save_via_model_save_func # noqa: F401 diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py index 926284bc..c7aad100 100644 --- a/angelslim/utils/config_parser.py +++ b/angelslim/utils/config_parser.py @@ -282,6 +282,13 @@ class QATTrainingConfig: loss_topk: Optional[int] = None kd_temperature: float = field(default=1.0) kd_alpha: float = field(default=0.5) + # ---- new loss-weight controls (compose LM + KD loss) ---- + # lm_loss_weight: weight on the HF CausalLM CE loss (labels must be set). + # kd_loss_weight: weight on the chosen distillation loss (kl / rkl / mse + # / kl_top_K / r_kl_top_K / cakld / jsd ...). Only loss components + # with a strictly-positive weight are computed AND logged. + lm_loss_weight: float = field(default=1.0) + kd_loss_weight: float = field(default=0.0) hf_args: Dict[str, Any] = field(default_factory=dict) diff --git a/angelslim/utils/zero3_io.py b/angelslim/utils/zero3_io.py index 8872f6b5..455d4433 100644 --- a/angelslim/utils/zero3_io.py +++ b/angelslim/utils/zero3_io.py @@ -174,16 +174,25 @@ def __init__( expert = nn.ModuleDict( { "gate_proj": nn.Linear( - self.hidden_dim, self.intermediate_dim, - bias=False, dtype=dtype, device=device, + self.hidden_dim, + self.intermediate_dim, + bias=False, + dtype=dtype, + device=device, ), "up_proj": nn.Linear( - self.hidden_dim, self.intermediate_dim, - bias=False, dtype=dtype, device=device, + self.hidden_dim, + self.intermediate_dim, + bias=False, + dtype=dtype, + device=device, ), "down_proj": nn.Linear( - self.intermediate_dim, self.hidden_dim, - bias=False, dtype=dtype, device=device, + self.intermediate_dim, + self.hidden_dim, + bias=False, + dtype=dtype, + device=device, ), } ) @@ -200,9 +209,7 @@ def forward( ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot( - top_k_index, 
num_classes=self.num_experts - ) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() @@ -234,8 +241,14 @@ def _is_fused_moe_experts(module) -> bool: plus ``num_experts``, ``hidden_dim``, ``intermediate_dim``, ``act_fn``.""" if isinstance(module, LinearizedMoeExperts): return False - required = ("gate_up_proj", "down_proj", "num_experts", "hidden_dim", - "intermediate_dim", "act_fn") + required = ( + "gate_up_proj", + "down_proj", + "num_experts", + "hidden_dim", + "intermediate_dim", + "act_fn", + ) return all(hasattr(module, a) for a in required) @@ -380,7 +393,9 @@ def zero3_empty_model_from_pretrained( # is_deepspeed_zero3_enabled() returns True). with no_init_weights(), no_tie_weights(): model = AutoModelForCausalLM.from_config( - config, torch_dtype=resolved, trust_remote_code=trust_remote_code, + config, + torch_dtype=resolved, + trust_remote_code=trust_remote_code, ) # Linearize fused MoE experts BEFORE weight loading so the loader can @@ -422,9 +437,7 @@ def _broadcast_into_target(src, target, *, is_buffer=False, key=None): then broadcast. * Single-process: direct copy. 
""" - dist_active = ( - torch.distributed.is_available() and torch.distributed.is_initialized() - ) + dist_active = torch.distributed.is_available() and torch.distributed.is_initialized() if is_zero3_param(target): with gathered_param_if_zero3(target, modifier_rank=0): @@ -481,7 +494,8 @@ def stream_load_weights(model, model_path, log_prefix="[zero3]"): base = key[: -len(".gate_up_proj")] src = reader.get_tensor(key) if rank == 0 else None n_exp = ( - int(src.shape[0]) if src is not None + int(src.shape[0]) + if src is not None else _infer_num_experts(base, name_to_param) ) for i in range(n_exp): @@ -495,11 +509,13 @@ def stream_load_weights(model, model_path, log_prefix="[zero3]"): gsrc = src[i].chunk(2, dim=-2)[0] if src is not None else None usrc = src[i].chunk(2, dim=-2)[1] if src is not None else None if _broadcast_into_target(gsrc, gtgt, key=gkey): - seen_targets.add(gkey); loaded += 1 + seen_targets.add(gkey) + loaded += 1 else: skipped += 1 if _broadcast_into_target(usrc, utgt, key=ukey): - seen_targets.add(ukey); loaded += 1 + seen_targets.add(ukey) + loaded += 1 else: skipped += 1 del src @@ -507,7 +523,8 @@ def stream_load_weights(model, model_path, log_prefix="[zero3]"): base = key[: -len(".down_proj")] src = reader.get_tensor(key) if rank == 0 else None n_exp = ( - int(src.shape[0]) if src is not None + int(src.shape[0]) + if src is not None else _infer_num_experts(base, name_to_param) ) for i in range(n_exp): @@ -518,7 +535,8 @@ def stream_load_weights(model, model_path, log_prefix="[zero3]"): continue dsrc = src[i] if src is not None else None if _broadcast_into_target(dsrc, dtgt, key=dkey): - seen_targets.add(dkey); loaded += 1 + seen_targets.add(dkey) + loaded += 1 else: skipped += 1 del src @@ -533,7 +551,8 @@ def stream_load_weights(model, model_path, log_prefix="[zero3]"): continue src = reader.get_tensor(key) if rank == 0 else None if _broadcast_into_target(src, tgt, is_buffer=is_buf, key=key): - seen_targets.add(key); loaded += 1 + 
seen_targets.add(key) + loaded += 1 else: skipped += 1 del src @@ -561,7 +580,7 @@ def _infer_num_experts(base, name_to_param): for name in name_to_param: if not name.startswith(prefix): continue - first = name[len(prefix):].split(".", 1)[0] + first = name[len(prefix) :].split(".", 1)[0] if first.isdigit(): ids.append(int(first)) return (max(ids) + 1) if ids else 0 @@ -575,11 +594,11 @@ def _infer_num_experts(base, name_to_param): _SCALE_SUFFIX_RULES = [ # (suffix_in_ckpt, quantizer_attr_on_QuantLinear, sub_attr, layer_name_rewrite) (".weight_zero_point", "weight_quantizer", "zero_point", None), - (".input_zero_point", "act_quantizer", "zero_point", None), - (".weight_scale", "weight_quantizer", "scale", None), - (".input_scale", "act_quantizer", "scale", None), - (".k_cache.scale", "qkv_quantizer", "scale", ".k_proj"), - (".v_cache.scale", "qkv_quantizer", "scale", ".v_proj"), + (".input_zero_point", "act_quantizer", "zero_point", None), + (".weight_scale", "weight_quantizer", "scale", None), + (".input_scale", "act_quantizer", "scale", None), + (".k_cache.scale", "qkv_quantizer", "scale", ".k_proj"), + (".v_cache.scale", "qkv_quantizer", "scale", ".v_proj"), ] # Longest first to avoid '.scale' winning over '.k_cache.scale'. 
_SCALE_SUFFIX_RULES.sort(key=lambda r: len(r[0]), reverse=True) @@ -605,13 +624,15 @@ def _expand_scale_targets(layer_name, qname, sub, named_modules): if layer_name.endswith(".experts.gate_up_proj"): base = layer_name[: -len(".gate_up_proj")] return [ - (n, qname, sub) for n in named_modules + (n, qname, sub) + for n in named_modules if n.startswith(base + ".") and (n.endswith(".gate_proj") or n.endswith(".up_proj")) ] if layer_name.endswith(".experts.down_proj"): base = layer_name[: -len(".down_proj")] return [ - (n, qname, sub) for n in named_modules + (n, qname, sub) + for n in named_modules if n.startswith(base + ".") and n.endswith(".down_proj") ] return [] @@ -637,7 +658,8 @@ def _copy_scale_into(src, target): if torch.distributed.is_available() and torch.distributed.is_initialized(): device = ( torch.device("cuda", torch.cuda.current_device()) - if torch.cuda.is_available() else target.device + if torch.cuda.is_available() + else target.device ) flag = torch.tensor(int(ok), device=device) torch.distributed.broadcast(flag, src=0) @@ -739,7 +761,10 @@ def consolidated_state_dict(model): def save_via_model_save_func( - quant_model, save_func, save_target_dir, prebuilt_state_dict=None, + quant_model, + save_func, + save_target_dir, + prebuilt_state_dict=None, ): """Invoke ``save_func.save(...)`` with the model's ``state_dict`` patched to return a consolidated rank-0 dict. 
diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml index 8a0d591a..f2940943 100644 --- a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml +++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml @@ -3,11 +3,12 @@ global: model: name: Qwen - model_path: /apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-30B-A3B + # Local directory or HuggingFace repo id of the base Qwen3-30B-A3B model. + model_path: Qwen/Qwen3-30B-A3B trust_remote_code: true torch_dtype: auto - # ZeRO-3 wants HF to NOT place the model with device_map; we pass - # ``None`` here and the Engine normalizes the string "None"/"distributed". + # ZeRO-3 wants HF to NOT place the model with ``device_map``. Pass + # ``None`` here; the Engine normalises the string "None"/"distributed". device_map: None low_cpu_mem_usage: true use_cache: false @@ -22,34 +23,35 @@ compression: activation: per-tensor ignore_layers: ["lm_head", "embed_tokens", "gate.weight"] QAT: - # Leave ``hf_dataset`` unset → End2EndTrainer.prepare_dataset falls back - # to the ``dataset:`` section below (local TextDataset). + # Leave ``hf_dataset`` unset → the trainer falls back to the local + # TextDataset defined in the ``dataset:`` section below. hf_dataset: null - # REQUIRED under ZeRO-3: bootstrap scales from a previous PTQ "real" - # checkpoint (same model_path works if it already carries scales; else - # point to an AngelSlim PTQ output dir). For the plain base model we - # just fall back to the init-value below, so this file MUST exist or - # the directory must contain *.safetensors even if they only carry - # base weights (scales will be initialised from weight_scale_init_value - # and activation_scale_init_value). 
- from_ptq_ckpt: /apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-30B-A3B + # REQUIRED under ZeRO-3: warm-start scales from a previous PTQ "real" + # checkpoint. Produce it with e.g. + # python tools/run.py -c configs/qwen3/ptq/fp8_static/qwen3-a3b_fp8_static.yaml + from_ptq_ckpt: ./output_ptq_30b/qwen3-a3b_fp8_static training_mode: end2end dist_mode: hf save_format: real do_train: true - loss_type: origin + # KD + LM loss combination: loss_type picks the KD variant, + # lm_loss_weight / kd_loss_weight control the outer mix. + # loss_type choices: kl | rkl | mse | cakld | kd | kl_top_K | r_kl_top_K + loss_type: cakld loss_topk: null kd_temperature: 1.0 kd_alpha: 0.5 + lm_loss_weight: 1.0 + kd_loss_weight: 1.0 plugin_config: enable_scale: true quant_config: use_weight_quant: true use_activation_quant: true use_qkv_quant: false - # Init values used under ZeRO-3 when weight data is not accessible - # (and as fallback whenever the from_ptq_ckpt does not carry scale - # tensors). + # Init values used when weight data is not accessible (e.g. under + # ZeRO-3) and for quantizer parameters missing from the PTQ + # checkpoint. 
weight_scale_init_value: 0.1 activation_scale_init_value: 0.1 learnable: @@ -59,11 +61,12 @@ compression: norm: false lwc: false hf_args: + bf16: true logging_steps: 1 logging_first_step: true per_device_train_batch_size: 1 gradient_accumulation_steps: 1 - learning_rate: 1.0e-4 + learning_rate: 1.0e-6 lr_scheduler_type: cosine num_train_epochs: 1 max_steps: 3 diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml index c8b89d2a..27109b26 100644 --- a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml +++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml @@ -3,7 +3,7 @@ global: model: name: Qwen - model_path: /apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-4B + model_path: Qwen/Qwen3-4B trust_remote_code: true torch_dtype: auto device_map: None @@ -20,18 +20,18 @@ compression: activation: per-tensor ignore_layers: ["lm_head", "embed_tokens"] QAT: - # Leave ``hf_dataset`` unset → End2EndTrainer.prepare_dataset falls back - # to the ``dataset:`` section below (local TextDataset). 
hf_dataset: null - from_ptq_ckpt: /apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-4B + from_ptq_ckpt: ./output_ptq/qwen3-4b_fp8_static training_mode: end2end dist_mode: hf save_format: real do_train: true - loss_type: origin + loss_type: cakld loss_topk: null kd_temperature: 1.0 kd_alpha: 0.5 + lm_loss_weight: 1.0 + kd_loss_weight: 1.0 plugin_config: enable_scale: true quant_config: @@ -47,11 +47,12 @@ compression: norm: false lwc: false hf_args: + bf16: true logging_steps: 1 logging_first_step: true per_device_train_batch_size: 1 gradient_accumulation_steps: 1 - learning_rate: 1.0e-4 + learning_rate: 1.0e-6 lr_scheduler_type: cosine num_train_epochs: 1 max_steps: 5 diff --git a/docs/source/features/quantization/index.md b/docs/source/features/quantization/index.md index 2048f22d..3878481c 100644 --- a/docs/source/features/quantization/index.md +++ b/docs/source/features/quantization/index.md @@ -13,5 +13,6 @@ awq gptq fp8_lepto qat +qat_zero3 daq ::: diff --git a/docs/source/features/quantization/qat_zero3.md b/docs/source/features/quantization/qat_zero3.md new file mode 100644 index 00000000..ae648b54 --- /dev/null +++ b/docs/source/features/quantization/qat_zero3.md @@ -0,0 +1,343 @@ +# QAT + DeepSpeed ZeRO-3 + +## 概述 + +本文档描述 AngelSlim QAT 模块的 **DeepSpeed ZeRO-3** 支持。核心动机:当基础模型体量超过单卡显存(例如 Qwen3-30B-A3B 这类 MoE 模型)时,需要把模型参数、梯度、优化器状态都切片到多张 GPU 上,HuggingFace + DeepSpeed 的 ZeRO-3 是一套成熟的方案,但直接用在 QAT 流程里会遇到以下问题: + +1. HuggingFace `from_pretrained` 在 ZeRO-3 下会让每个 rank 先在 CPU 上加载**完整** state_dict,再 partition,峰值内存约 `world_size × model_size`,大模型必然 OOM。 +2. QAT 会把原始 `nn.Linear` 替换成 `QuantLinear`,又会把 fused MoE expert 拆成 per-expert `nn.Linear`。这些替换操作如果在 ZeRO-3 sharded 参数上执行,会同时触发切片 / 合并 / 重分发,逻辑极易出错、峰值不可控。 +3. QAT 的 activation scale 默认通过 forward 校准(lazy init),在 ZeRO-3 下无法运行;weight scale 的初始化依赖读取完整权重,也不可行。 +4. 
`convert()` 把 `QuantLinear` 转为 `QDQModule`,以及 `save_format=real` 导出压缩权重时,需要在多卡之间合并参数并保持 CPU 内存在 **rank0 一份**。 + +针对上述问题,本实现的关键设计如下: + +- **每个 rank 独立构造一个空模型**(不读磁盘权重),通过 `deepspeed.zero.Init` 立即 partition,峰值内存仅 `model_size / world_size`。 +- **fused MoE 拆解发生在空模型阶段**(duck-typing 识别 `gate_up_proj / down_proj / num_experts / hidden_dim / intermediate_dim / act_fn`),新建的 per-expert `nn.Linear` 被 `deepspeed.zero.Init` 立即切片,**不需要从旧的 fused tensor 拷贝数据**。 +- **权重和 scale 都通过 safetensors 流式加载**:rank0 按文件顺序读取,其它 rank 通过 `GatheredParameters(modifier_rank=0)` 接收数据。fused MoE 的 `.experts.gate_up_proj` / `.experts.down_proj` key 会被自动切片写入每个 per-expert target。 +- **ZeRO-3 下 scale 强制从外部 PTQ checkpoint 加载**(`from_ptq_ckpt`),跳过 forward 校准;未命中的 scale 按 `weight_scale_init_value` / `activation_scale_init_value` 填充。 +- **`convert / save` 只在 rank0 构造 `QDQModule` 并直接写入合并 state_dict**,其它 rank 只参与 NCCL gather 的 collective,不持有 CPU 数据。 + +整体流程下,非 ZeRO-3 路径的行为与 main 分支完全保持一致——所有 ZeRO-3 逻辑都收敛在一个新文件 `angelslim/utils/zero3_io.py` 中,其他文件只做少量薄调用。 + +## 架构设计 + +### 模块目录结构 + +``` +angelslim/ +├── utils/ +│ └── zero3_io.py # ALL ZeRO-3 helpers: detection, +│ # gather/scatter, empty-model build, +│ # streaming weight/scale loaders, +│ # consolidated save, optimizer patch. +├── compressor/qat/ +│ ├── qat.py # ZeRO-3 branch for convert() / save() +│ ├── modules/quantizer.py # init scales from init_value when +│ # weight data is not accessible; dtype +│ # alignment for DeepSpeed autocast. 
+│ ├── plugins/learnable_scale.py # stream_load_scales + skip lazy_init +│ │ # + gathered quant_inplace +│ └── trainers/end2end_trainer.py # lm_loss + kd_loss composition with +│ # cakld support + per-component logging +├── data/text_dataset.py # supervise ONLY the last assistant turn +├── models/base_model.py # from_pretrained → ZeRO-3 path +├── engine.py # normalise ``device_map`` string +└── utils/ + ├── config_parser.py # QATTrainingConfig new fields + ├── utils.py # set_op_by_name handles string-indexed + └── __init__.py # re-export zero3_io helpers +``` + +### 执行流程(ZeRO-3 路径) + +``` +tools/run.py + └── _prewarm_hf_deepspeed_config() # register HfTrainerDeepSpeedConfig + └── Engine.prepare_model() + └── BaseLLMModel.from_pretrained() # ZeRO-3 branch + ├── zero3_empty_model_from_pretrained() + │ ├── AutoModelForCausalLM.from_config(...) + │ │ # triggers deepspeed.zero.Init for every Parameter + │ └── linearize_moe_experts_empty() + │ # fused Qwen3MoeExperts → empty LinearizedMoeExperts + └── stream_load_weights() # rank0 reads safetensors + └── QAT.__init__() + ├── init_ptq() → Qwen.replace_moe() # no-op: already linearised + └── register LearnableScalePlugin(from_ptq_ckpt_dir=...) + └── LearnableScalePlugin.before_train() + ├── replace nn.Linear with QuantLinear + │ └── Quantizer allocates scale Parameters using init_value + │ (no dependency on weight data) + └── stream_load_scales(from_ptq_ckpt) # fill weight/act/kv scales + └── End2EndTrainer.prepare_trainer() + ├── _init_optimizer() with id-deduped scale/LWC params + ├── patch_deepspeed_duplicate_check() # scales may be tied + └── HF Trainer.train() + # student + teacher forward → lm_loss + kd_loss composition + └── QAT.convert() + QAT.save() + # rank0-only QDQModule + consolidated state_dict → single rank write +``` + +## 使用方法 + +### 前置条件 + +1. 安装依赖:`deepspeed`、`safetensors`、`compressed-tensors`(可选,用于读取导出的 fp8 checkpoint)。 +2. 
硬件:支持 NCCL 的多 GPU 节点;ZeRO-3 路径要求 `torchrun --nproc_per_node=N`。 + +### 完整两阶段流程 + +#### 阶段 1:PTQ 校准生成初始 scale + +ZeRO-3 QAT 启动时**不再**跑 forward 校准,因此必须先产出一个带 scale 的 PTQ checkpoint。使用现有 PTQ 配置(单卡即可): + +```bash +python tools/run.py \ + -c configs/qwen3/ptq/fp8_static/qwen3-a3b_fp8_static.yaml \ + --model-path /path/to/Qwen3-30B-A3B \ + --save-path ./output_ptq_30b +``` + +产出的 `./output_ptq_30b/qwen3-a3b_fp8_static/model-*.safetensors` 中包含 `.weight_scale` 和 `.input_scale`。 + +#### 阶段 2:ZeRO-3 QAT + +```bash +bash scripts/qat/run_qat_for_qwen_30b_a3b_zero3.sh +``` + +或直接: + +```bash +torchrun --nproc_per_node=8 tools/run.py \ + -c configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml +``` + +训练完成后,压缩权重会保存到 `./output/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3/final_quant_checkpoint/`。 + +### 最小 4B 烟囱测试 + +快速验证流程(2 张卡、5 步训练): + +```bash +# Step 1: PTQ +python tools/run.py \ + -c configs/qwen3/ptq/fp8_static/qwen3-4b_fp8_static.yaml \ + --model-path Qwen/Qwen3-4B \ + --save-path ./output_ptq + +# Step 2: QAT +bash scripts/qat/run_qat_for_qwen_4b_zero3.sh +``` + +## 配置说明 + +### ZeRO-3 QAT 新增 / 修改字段(`compression.QAT`) + +| 字段 | 类型 | 默认值 | 描述 | +|------|------|--------|------| +| `from_ptq_ckpt` | str / null | `null` | PTQ `save_format="real"` 产出目录。ZeRO-3 下**必填**,否则 `before_train` 会报错。目录可指向最顶层(自动识别嵌套的 `final_quant_checkpoint/`)。 | +| `lm_loss_weight` | float | `1.0` | HF CausalLM CE loss 的权重。为 0 则不计算 lm loss,也不出现在日志中。 | +| `kd_loss_weight` | float | `0.0` | KD loss 的权重(使用 `loss_type` 选定的 KD 变体)。为 0 则不计算 KD,也不启动 teacher forward。 | + +注:`lm_loss_weight` / `kd_loss_weight` 任意一个 > 0 即可参与训练;两者都为 0 会在 `compute_loss` 中抛错。 + +### KD 变体(`loss_type`) + +| `loss_type` | 描述 | +|-------------|------| +| `kl` | `KL(teacher || student)`(forward KL),per-valid-token 平均 | +| `rkl` | `KL(student || teacher)`(reverse / backward KL),per-valid-token 平均 | +| `mse` | student 与 teacher logits 的 MSE(per-valid-token 平均) | +| `cakld` | Confidence-Aware 
KL Distillation:按 teacher 在 label 上的概率做 token-wise 的 `fkl` / `rkl` 混合,`conf * rkl + (1 - conf) * fkl` | +| `kd` | 经典 temperature KD:`T² * KL(soft_student || soft_teacher)`,保留兼容 | +| `kl_top_K` / `r_kl_top_K` | top-K token 上的 forward / reverse KL。`K` 可写在 `loss_type` 字符串里(例如 `kl_top_1000`),或通过 `loss_topk` 字段指定 | +| `origin` | 纯 HF CE loss(等价于 `kd_loss_weight = 0`) | + +选择 `loss_type = cakld` 时的核心公式(参考 `_compute_kd_components`): + +```python +# 仅在 labels != -100 的 token 上计算 +forward_kl = KL(log_softmax(student), softmax(teacher)) +backward_kl = KL(log_softmax(teacher), softmax(student)) +conf = softmax(teacher).gather(-1, label) # teacher 对目标 token 的置信度 +cakld = (conf * backward_kl + (1 - conf) * forward_kl).mean() +``` + +### 训练日志 + +`QATSeq2SeqTrainer.log()` 会自动把下列指标注入 HF Trainer 的标准日志字典(所以 wandb / console / tqdm 都能看到),仅当对应权重 > 0 时才会出现: + +| 指标 | 含义 | +|------|------| +| `lm_loss` | HF CausalLM CE loss(仅对 assistant 回复位置计算,见下) | +| `kd/<loss_type>` | 当前选定的 KD 主 loss(`cakld` / `kl` / ...) | +| `kd/forward_kl` | 诊断用:`KL(teacher || student)`,始终在 `kd_loss_weight > 0` 时打印 | +| `kd/backward_kl` | 诊断用:`KL(student || teacher)` | +| `total_loss` | `lm_loss_weight * lm_loss + kd_loss_weight * kd/<loss_type>` | + +示例输出(30B-A3B,3 步训练): + +``` +{'loss': 2.08, 'grad_norm': 43.1, 'learning_rate': 1e-6, + 'lm_loss': 1.23, 'kd/cakld': 0.0075, 'kd/forward_kl': 0.0075, 'kd/backward_kl': 0.0075, + 'total_loss': 1.24, 'epoch': 0.03} +``` + +### Dataset:仅监督最后一个 assistant 回复 + +对于 JSONL 格式的 SFT 数据(`messages` / `conversations` / `input+output` 三种 schema),`TextDataset._load_jsonl_data` 现在只对**最后一个 assistant 回复**位置计算 loss: + +- 拼接 prompt(对话中最后一个 assistant 之前的所有 turn)并通过 `apply_chat_template(..., add_generation_prompt=True)` tokenize,得到 `prompt_len`。 +- 拼接完整对话(含最后 assistant)tokenize 得到 `input_ids`,`labels = input_ids.clone()`。 +- 把 `labels[:, :prompt_len]` 和 padding 位置都置为 `-100`,HF CausalLM loss 会自动忽略。 + +这与 HF CausalLM 内部的 shift 行为一致(`shift_logits[..., :-1]` 对齐 `shift_labels[..., 1:]`),**不需要**手动 `roll`。 + +### 
DeepSpeed 配置 + +参考 `configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json`。关键项: + +```json +{ + "bf16": {"enabled": "auto"}, + "zero_optimization": { + "stage": 3, + "stage3_gather_16bit_weights_on_model_save": true, + "overlap_comm": true, + "contiguous_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto" +} +``` + +`hf_args` 中需要设 `bf16: true`(或 `fp16: true`)显式开启混合精度,否则 DeepSpeed 的 `zero3_linear_wrap` 会默认用 fp16 autocast,而模型权重是 bf16,导致 dtype 失配。 + +### Quantizer 初始值 + +当 ZeRO-3 启用(或 weight 是 ZeRO-3 sharded / meta / 0 numel 的任一情况)时,`Quantizer._init_quant_params` 不再依赖权重数值,而是直接按 shape 创建 `nn.Parameter`: + +| 配置项 | 默认值 | 描述 | +|------|--------|------| +| `weight_scale_init_value` | `1.0` | Weight quantizer scale 的初始值,在 `from_ptq_ckpt` 未命中时保留 | +| `activation_scale_init_value` | `1.0` | Activation quantizer scale 的初始值 | + +典型场景下 PTQ 产出的 weight scale 约为 `max(|W|) / 448 ≈ 1e-3`,建议在 yaml 里把 init value 设为 `0.1` 左右作为保底(`from_ptq_ckpt` 命中时这些值会被覆盖)。 + +### 示例 yaml 关键片段 + +```yaml +compression: + name: QAT + quantization: + name: fp8_static + quant_method: + weight: per-tensor + activation: per-tensor + ignore_layers: ["lm_head", "embed_tokens", "gate.weight"] + QAT: + hf_dataset: null + from_ptq_ckpt: ./output_ptq_30b/qwen3-a3b_fp8_static + training_mode: end2end + dist_mode: hf + save_format: real + loss_type: cakld + lm_loss_weight: 1.0 + kd_loss_weight: 1.0 + plugin_config: + enable_scale: true + quant_config: + use_weight_quant: true + use_activation_quant: true + weight_scale_init_value: 0.1 + activation_scale_init_value: 0.1 + learnable: + act_scale: false + weight_scale: true + kv_scale: false + norm: false + lwc: false + hf_args: + bf16: true + per_device_train_batch_size: 1 + learning_rate: 1.0e-6 + gradient_checkpointing: true + deepspeed: configs/qwen3/qat/fp8_static/learn_scale/ds_config_zero3.json +``` + +完整示例: +- 
`configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml` +- `configs/qwen3/qat/fp8_static/learn_scale/qwen3-30b-a3b_fp8_static_end2end_learn_scale_zero3.yaml` + +## 核心实现要点 + +### `zero3_io.py` 中的主要 API + +| 函数 / 类 | 作用 | +|-----------|------| +| `is_deepspeed_zero3_enabled()` | 通过 HF 的 `HfTrainerDeepSpeedConfig` 弱引用判断是否已注册 ZeRO-3 配置 | +| `is_zero3_param(p)` | 判断 Parameter 是否带有 `ds_id / ds_status / ds_numel / ds_tensor` 元数据 | +| `gathered_param_if_zero3(p, modifier_rank=None)` | 上下文管理器;非 ZeRO-3 参数为 no-op | +| `LinearizedMoeExperts` | 通用空 per-expert `nn.Linear` 容器,`forward` 与 HF fused 等价 | +| `linearize_moe_experts_empty(model)` | duck-typing 扫描并原地替换,在 `deepspeed.zero.Init` 内构造以直接 partition | +| `zero3_empty_model_from_pretrained(model_path, ...)` | `no_init_weights` + `from_config` 构造空模型 + 自动拆 MoE | +| `stream_load_weights(model, model_path)` | 流式灌权,支持 fused MoE key 的 per-expert 切片分发 | +| `stream_load_scales(model, ckpt_dir)` | 流式灌 scale,支持 `.weight_scale / .input_scale / .k_cache.scale / .v_cache.scale` | +| `save_via_model_save_func(quant_model, save_func, path, prebuilt_state_dict)` | 只在 rank0 调用原 save_func,`state_dict()` 被 patch 返回合并后的字典 | +| `patch_deepspeed_duplicate_check()` | 置空 `DeepSpeedEngine._check_for_duplicates`,允许 tied scale 参数 | + +### `Quantizer` 的变更 + +- 新增 `weight_shape` 构造参数:`QuantLinear.__init__` 传入 `(out_features, in_features)`,使得 Quantizer 在不访问权重数据的情况下也能计算 scale 的形状。 +- 新增 `weight_scale_init_value` / `activation_scale_init_value` 配置项;`_init_quant_params` / `_init_lwc_params` 在 `_needs_external_weight_init(x)` 为真时使用。 +- `QuantLinear.forward` 末尾的 `F.linear` 现在包在 `torch.amp.autocast(device_type="cuda", enabled=False)` 中,并在调用前把 `input.dtype` 对齐到 `weight.dtype`,以避免 DeepSpeed `zero3_linear_wrap` 的 autocast 把 bf16 input 回转成 fp16。 +- `fake_quant` 末尾把 `out` cast 回 `x.dtype`,防止 `bf16 * fp32 = fp32` 的 dtype 泄漏。 + +### 优化器去重 + +`End2EndTrainer._init_optimizer` 使用 `_unique_named_params(...)` 根据 `id()` 去重收集 trainable 的 
scale / LWC 参数,避免同一 Parameter 被多次加入 param group(在 MoE expert 共享 tensor 的场景下会触发)。配合 `patch_deepspeed_duplicate_check()` 即可通过 DeepSpeed 的安全检查。 + +### `convert` + `save` 的内存控制 + +`QAT.convert` 在 ZeRO-3 路径下: + +1. 对每个 `QuantLinear`:**所有 rank** 都进入 `gathered_param_if_zero3` 拿到完整 weight(NCCL collective 保持对称),但**只 rank0** 保留 CPU clone。 +2. rank0 在临时 `QDQModule` 内部跑一次 fp8/int 量化,把 `weight / weight_scale / input_scale / bias` 取出塞进 `self._rank0_state_dict`,随后丢弃临时模块。 +3. **不修改模型结构**:保持 `QuantLinear`,使得第二轮扫描(收集非 QuantLinear 参数,如 embed、lm_head、layernorm、MoE router gate)在所有 rank 上 `named_parameters` 顺序一致,collective gather 不会死锁。 +4. `QAT.save` 把 `_rank0_state_dict` 透传给 `save_via_model_save_func`,后者 patch `hf_model.state_dict`,**只 rank0** 调原 `save_func.save(...)`。 + +rank>0 convert 阶段 CPU 峰值 ≈ 一层的完整 weight(几十 MB 到 GB 量级);rank0 峰值 ≈ 累积的合并 state_dict(完整模型大小)。 + +## 已验证场景 + +| 场景 | 模型 | 硬件 | 结果 | +|------|------|------|------| +| Dense ZeRO-3 QAT | Qwen3-4B | 2×H20 | ✓ PTQ→QAT→save 打通,产物能被 transformers 加载 | +| MoE ZeRO-3 QAT | Qwen3-30B-A3B(48 层 × 128 experts)| 8×H20 | ✓ `stream_load_scales` 命中 37248 个 scale,训练 loss 稳定,输出 31 GB fp8 checkpoint | +| KD + LM loss 组合 | 同上 | 同上 | ✓ `lm_loss / kd/cakld / kd/forward_kl / kd/backward_kl / total_loss` 按权重打印 | +| 最后 assistant 仅监督 | TextDataset(jsonl) | - | ✓ 首个 valid label idx 落在 `<\|im_start\|>assistant\n` 之后 | +| 非 ZeRO-3 回归 | Qwen3-4B 单卡 PTQ | 1×H20 | ✓ 行为与 main 一致,无回归 | + +## 代码规范 / 提交须知 + +根据公司对外开源代码格式规范,本次改动遵循: + +1. **所有新增 / 修改代码无中文注释**。`git diff main..HEAD -- '*.py' '*.yaml' '*.json' '*.sh'` 输出中不包含中文字符。 +2. **所有 yaml / sh / py 内置路径均为相对路径或占位符**。`git diff main..HEAD` 中不包含 `apdcephfs` / 内部挂载点 / 任何账号或密码。 +3. 
**遵守仓库 pre-commit 规范**(black line-length 99 + isort black profile + flake8)。提交前执行: + + ```bash + pre-commit install # 一次性安装钩子 + pre-commit run --all-files # 或 git commit 时自动触发 + ``` + + 本次所有改动均已通过 Black / isort / Flake8。 + +## 目前的限制 + +- Blockwise 训练路径未针对 ZeRO-3 适配。ZeRO-3 分支只在 `training_mode: end2end` + `dist_mode: hf` 下工作。 +- 仅支持 `save_format` 为 `real` / `real_and_kvcache` / `save_kvcache_only`;`fake` 格式依赖 `trainer.external_trainer.model.state_dict()`,在 ZeRO-3 下只产出 rank 本地切片,未做合并。 +- `loss_type` 新增的 `cakld` 默认 reduction 为 mean;如需按 batch token 归一化的其它 reduction,请在 `_compute_kd_components` 中自行扩展。 diff --git a/scripts/qat/test_moe_zero3_build.py b/scripts/qat/test_moe_zero3_build.py deleted file mode 100644 index eb18f64f..00000000 --- a/scripts/qat/test_moe_zero3_build.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Standalone smoke test: under ZeRO-3, build an empty Qwen3-30B-A3B model -with layers trimmed to 2, linearize MoE, stream weights, and verify the -per-expert Linear shapes are correct. - -Run: - torchrun --nproc_per_node=2 scripts/qat/test_moe_zero3_build.py -""" -import os -import sys -import torch -import torch.distributed as dist - -# Initialise torch distributed first (required by deepspeed.zero.Init). -dist.init_process_group(backend="nccl") -torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - -# Register HfTrainerDeepSpeedConfig so is_deepspeed_zero3_enabled() returns True. 
-from transformers import Seq2SeqTrainingArguments # noqa: E402 - -ds_config = { - "zero_optimization": {"stage": 3, "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9}, - "train_batch_size": 2, - "train_micro_batch_size_per_gpu": 1, - "gradient_accumulation_steps": 1, - "bf16": {"enabled": True}, -} -_hf_args = Seq2SeqTrainingArguments( - output_dir="./tmp_out", deepspeed=ds_config, bf16=True, - per_device_train_batch_size=1, -) - -from transformers import AutoConfig # noqa: E402 -from angelslim.utils import ( # noqa: E402 - is_deepspeed_zero3_enabled, is_zero3_param, - zero3_empty_model_from_pretrained, stream_load_weights, linearize_moe_experts_empty, -) - -assert is_deepspeed_zero3_enabled(), "HF ZeRO-3 not registered" - -MODEL_PATH = "/apdcephfs_zwfy2/share_301053287/brunosu/all_models/Qwen3-30B-A3B" - -# Trim layers to 2 for quick iteration. -cfg = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) -cfg.num_hidden_layers = 2 -cfg.use_cache = False - -# Build empty model from the trimmed config. 
-from transformers import AutoModelForCausalLM # noqa: E402 -from transformers.initialization import no_init_weights, no_tie_weights # noqa: E402 - -with no_init_weights(), no_tie_weights(): - model = AutoModelForCausalLM.from_config(cfg, torch_dtype=torch.bfloat16, trust_remote_code=True) - -if dist.get_rank() == 0: - print("Built empty model:", type(model).__name__) - print("num layers:", len(model.model.layers)) - # Inspect first layer MoE before linearization - layer = model.model.layers[0] - mlp = layer.mlp - print("mlp type:", type(mlp).__name__) - if hasattr(mlp, "experts"): - experts = mlp.experts - print("experts type:", type(experts).__name__) - print("has gate_up_proj:", hasattr(experts, "gate_up_proj")) - if hasattr(experts, "gate_up_proj"): - print("gate_up_proj ds_shape:", getattr(experts.gate_up_proj, "ds_shape", None)) - print("is_zero3:", is_zero3_param(experts.gate_up_proj)) - -replaced = linearize_moe_experts_empty(model, dtype=torch.bfloat16) -dist.barrier() - -if dist.get_rank() == 0: - print(f"Replaced {replaced} fused experts") - layer = model.model.layers[0] - mlp = layer.mlp - experts = mlp.experts - print("After linearize - experts type:", type(experts).__name__) - print("num_experts attr:", experts.num_experts) - # Inspect one expert - e0 = experts[0] - print("expert 0:", type(e0).__name__) - print(" gate_proj weight shape:", e0["gate_proj"].weight.shape, - "is_zero3:", is_zero3_param(e0["gate_proj"].weight)) - print(" gate_proj ds_shape:", getattr(e0["gate_proj"].weight, "ds_shape", None)) - -# Now stream load -stream_load_weights(model, MODEL_PATH, log_prefix=f"[rank{dist.get_rank()}]") -dist.barrier() - -# Note: because we trimmed to 2 layers, the checkpoint has weights for -# layers [0..47], so layers [2..47] will appear as "unused keys" in the -# missing-key summary (not an error). 
- -if dist.get_rank() == 0: - print("\n=== Build + linearize + stream_load_weights OK ===") - -dist.destroy_process_group() diff --git a/tools/run.py b/tools/run.py index be84bf4c..c15924c2 100644 --- a/tools/run.py +++ b/tools/run.py @@ -289,9 +289,7 @@ def _prewarm_hf_deepspeed_config(config): output_dir=config.global_config.save_path, **hf_args, ) - print_info( - "[DeepSpeed pre-warm] HfTrainerDeepSpeedConfig registered before model load." - ) + print_info("[DeepSpeed pre-warm] HfTrainerDeepSpeedConfig registered before model load.") return trainer_args From 6e781978f56529abc4a6055fc32090e1683f8b3f Mon Sep 17 00:00:00 2001 From: root Date: Mon, 27 Apr 2026 13:27:05 +0800 Subject: [PATCH 04/13] update document --- .gitignore | 2 ++ .../source/features/quantization/qat_zero3.md | 20 ------------------- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index 00e37b2b..d8fcb3d2 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,8 @@ eval/ *_ckpt*/ output/ outputs/ +output*/ +logs*/ outs/ wandb/ tools/results/ diff --git a/docs/source/features/quantization/qat_zero3.md b/docs/source/features/quantization/qat_zero3.md index ae648b54..49c8e2ff 100644 --- a/docs/source/features/quantization/qat_zero3.md +++ b/docs/source/features/quantization/qat_zero3.md @@ -321,23 +321,3 @@ rank>0 convert 阶段 CPU 峰值 ≈ 一层的完整 weight(几十 MB 到 GB | 最后 assistant 仅监督 | TextDataset(jsonl) | - | ✓ 首个 valid label idx 落在 `<\|im_start\|>assistant\n` 之后 | | 非 ZeRO-3 回归 | Qwen3-4B 单卡 PTQ | 1×H20 | ✓ 行为与 main 一致,无回归 | -## 代码规范 / 提交须知 - -根据公司对外开源代码格式规范,本次改动遵循: - -1. **所有新增 / 修改代码无中文注释**。`git diff main..HEAD -- '*.py' '*.yaml' '*.json' '*.sh'` 输出中不包含中文字符。 -2. **所有 yaml / sh / py 内置路径均为相对路径或占位符**。`git diff main..HEAD` 中不包含 `apdcephfs` / 内部挂载点 / 任何账号或密码。 -3. 
**遵守仓库 pre-commit 规范**(black line-length 99 + isort black profile + flake8)。提交前执行: - - ```bash - pre-commit install # 一次性安装钩子 - pre-commit run --all-files # 或 git commit 时自动触发 - ``` - - 本次所有改动均已通过 Black / isort / Flake8。 - -## 目前的限制 - -- Blockwise 训练路径未针对 ZeRO-3 适配。ZeRO-3 分支只在 `training_mode: end2end` + `dist_mode: hf` 下工作。 -- 仅支持 `save_format` 为 `real` / `real_and_kvcache` / `save_kvcache_only`;`fake` 格式依赖 `trainer.external_trainer.model.state_dict()`,在 ZeRO-3 下只产出 rank 本地切片,未做合并。 -- `loss_type` 新增的 `cakld` 默认 reduction 为 mean;如需按 batch token 归一化的其它 reduction,请在 `_compute_kd_components` 中自行扩展。 From c9b3d87f9d471a1b736ae488d01883fc83d25804 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 27 Apr 2026 13:44:28 +0800 Subject: [PATCH 05/13] update document --- angelslim/compressor/qat/modules/quantizer.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/angelslim/compressor/qat/modules/quantizer.py b/angelslim/compressor/qat/modules/quantizer.py index f7754609..db1d203d 100644 --- a/angelslim/compressor/qat/modules/quantizer.py +++ b/angelslim/compressor/qat/modules/quantizer.py @@ -641,13 +641,7 @@ def forward(self, input: torch.Tensor): # autocast) may have cast ``input`` to fp16 even though we run in # bf16. Align to the (fake-quantised) weight dtype so F.linear # stays consistent. - if input.dtype != weight.dtype: - input = input.to(weight.dtype) - # Disable autocast around the matmul so the DeepSpeed - # zero3_linear_wrap wrapper (decorated with ``autocast_custom_fwd``) - # does NOT silently re-cast ``input`` back to the autocast dtype. 
- with torch.amp.autocast(device_type="cuda", enabled=False): - output = self.fwd_func(input, weight, self.bias) + output = self.fwd_func(input.to(self.weight.dtype), weight.to(self.weight.dtype), self.bias) if self.use_qkv_quant: output = self.qkv_quantizer(output) return output From 78b0ca42d422d8d3f4688304687374b5554f01e6 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 27 Apr 2026 20:03:12 +0800 Subject: [PATCH 06/13] add zero3 qat for hunyuan v3 --- angelslim/models/llm/hunyuan_v3_moe.py | 61 +++++++++++----- angelslim/utils/zero3_io.py | 15 +++- .../fp8_static/hunyuanv3_a20b_fp8_static.yaml | 38 ++++++++++ .../learn_scale/ds_config_zero3.json | 45 ++++++++++++ ..._fp8_static_end2end_learn_scale_zero3.yaml | 72 +++++++++++++++++++ .../qat/run_qat_for_hunyuanv3_a20b_zero3.sh | 11 +++ 6 files changed, 223 insertions(+), 19 deletions(-) create mode 100644 configs/hunyuan/ptq/fp8_static/hunyuanv3_a20b_fp8_static.yaml create mode 100644 configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json create mode 100644 configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml create mode 100755 scripts/qat/run_qat_for_hunyuanv3_a20b_zero3.sh diff --git a/angelslim/models/llm/hunyuan_v3_moe.py b/angelslim/models/llm/hunyuan_v3_moe.py index d8919643..fdfd9c5c 100644 --- a/angelslim/models/llm/hunyuan_v3_moe.py +++ b/angelslim/models/llm/hunyuan_v3_moe.py @@ -16,20 +16,50 @@ import torch import torch.nn as nn -from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.models.hy_v3.modeling_hy_v3 import ( ALL_ATTENTION_FUNCTIONS, HYV3Experts, + HYV3TopKRouter, apply_rotary_pos_emb, eager_attention_forward, ) from ...compressor.quant.core import PTQSaveVllmHF +from ...utils import is_deepspeed_zero3_enabled from ...utils.utils import find_layers, find_parent_layer_and_sub_name from ..base_model import BaseLLMModel from ..model_factory import SlimModelFactory +def _patch_hyv3_router_for_zero3(): + if 
getattr(HYV3TopKRouter, "_angelslim_zero3_dtype_patch", False): + return + + def patched_forward( + self, + hidden_states: torch.Tensor, + e_score_correction_bias: torch.Tensor, + ) -> tuple: + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = nn.functional.linear( + hidden_states.to(self.weight.dtype), + self.weight, + ).to(torch.float32) + routing_weights = torch.sigmoid(router_logits) + + scores_for_choice = routing_weights + e_score_correction_bias + _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False) + top_k_weights = routing_weights.gather(1, top_k_index) + + top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-20) + top_k_weights = top_k_weights * self.router_scaling_factor + + return router_logits, top_k_weights, top_k_index + + HYV3TopKRouter.forward = patched_forward + HYV3TopKRouter._angelslim_zero3_dtype_patch = True + + class HYV3ExpertsWithLinear(HYV3Experts): """Wrapper around HYV3Experts that exposes per-expert weights as nn.Linear modules. 
@@ -138,19 +168,18 @@ def from_pretrained( ): attn_implementation = "eager" torch_dtype = torch.bfloat16 - self.model = AutoModelForCausalLM.from_pretrained( - model_path, - attn_implementation=attn_implementation, + if is_deepspeed_zero3_enabled(): + _patch_hyv3_router_for_zero3() + + super().from_pretrained( + model_path=model_path, torch_dtype=torch_dtype, device_map=device_map, trust_remote_code=trust_remote_code, low_cpu_mem_usage=low_cpu_mem_usage, use_cache=use_cache, - ) - - # Load tokenizer - self.tokenizer = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=trust_remote_code + using_multi_nodes=using_multi_nodes, + attn_implementation=attn_implementation, ) def replace_moe(self): @@ -179,9 +208,9 @@ def get_observer_layers(self): "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj", - "shared_mlp.gate_proj", - "shared_mlp.up_proj", - "shared_mlp.down_proj", + "shared_experts.gate_proj", + "shared_experts.up_proj", + "shared_experts.down_proj", ] expert_pattern = [ r"model\.layers\.\d+\.mlp\.experts\.\d+\.gate_proj", @@ -326,11 +355,9 @@ def patched_forward( key_states, value_states, attn_module.layer_idx, cache_kwargs ) - attention_interface = eager_attention_forward - if attn_module.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[ - attn_module.config._attn_implementation - ] + attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface( + attn_module.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( attn_module, diff --git a/angelslim/utils/zero3_io.py b/angelslim/utils/zero3_io.py index 455d4433..614ee5bd 100644 --- a/angelslim/utils/zero3_io.py +++ b/angelslim/utils/zero3_io.py @@ -211,7 +211,18 @@ def forward( with torch.no_grad(): expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + expert_hit_mask = 
torch.greater(expert_mask.sum(dim=(-1, -2)), 0) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + # ZeRO-3 gathers parameters when a Linear is entered. All ranks must + # enter the same experts in the same order even when their local + # batches route to different experts, otherwise collectives deadlock. + expert_hit_int = expert_hit_mask.to(torch.int32) + torch.distributed.all_reduce( + expert_hit_int, + op=torch.distributed.ReduceOp.MAX, + ) + expert_hit_mask = expert_hit_int.to(torch.bool) + expert_hit = expert_hit_mask.nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] @@ -219,7 +230,7 @@ def forward( continue top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] - expert_layer = getattr(self, str(int(expert_idx))) + expert_layer = getattr(self, str(int(expert_idx.item()))) gate = expert_layer["gate_proj"](current_state) up = expert_layer["up_proj"](current_state) current_hidden_states = self.act_fn(gate) * up diff --git a/configs/hunyuan/ptq/fp8_static/hunyuanv3_a20b_fp8_static.yaml b/configs/hunyuan/ptq/fp8_static/hunyuanv3_a20b_fp8_static.yaml new file mode 100644 index 00000000..06558ffd --- /dev/null +++ b/configs/hunyuan/ptq/fp8_static/hunyuanv3_a20b_fp8_static.yaml @@ -0,0 +1,38 @@ +# Global configuration of pipeline +global: + save_path: ./output_ptq_hy3 + +# Simplified Configuration for LLM compression +model: + name: HYV3MoE + model_path: tencent/Hy3-preview + trust_remote_code: true + low_cpu_mem_usage: true + use_cache: false + torch_dtype: auto + device_map: auto + +# Compression configuration +compression: + name: PTQ + quantization: + name: fp8_static + bits: 8 + quant_method: + weight: "per-tensor" + activation: "per-tensor" + kv_cache: "per-tensor" + ignore_layers: + - "lm_head" + - "model.embed_tokens" + - "gate.weight" + cpu_convert: true + save_name: "fp8" + +# Dataset for calibration +dataset: + name: TextDataset + data_path: 
./dataset/sharegpt_gpt4/sharegpt_gpt4_256.jsonl + max_seq_length: 2048 + num_samples: 512 + batch_size: 1 diff --git a/configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json b/configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json new file mode 100644 index 00000000..e513dd7b --- /dev/null +++ b/configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json @@ -0,0 +1,45 @@ +{ + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "total_num_steps": "auto", + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "none" + }, + "offload_param": { + "device": "none" + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 10, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml new file mode 100644 index 00000000..d110e90d --- /dev/null +++ b/configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml @@ -0,0 +1,72 @@ +global: + save_path: ./output_hy3_zero3 + +model: + name: HYV3MoE + model_path: tencent/Hy3-preview + trust_remote_code: true + torch_dtype: auto + device_map: None + low_cpu_mem_usage: true + use_cache: false + 
+compression: + name: QAT + quantization: + name: fp8_static + bits: 8 + quant_method: + weight: per-tensor + activation: per-tensor + ignore_layers: ["lm_head", "embed_tokens", "gate.weight"] + QAT: + hf_dataset: null + from_ptq_ckpt: ./output_ptq_hy3/hunyuanv3_a20b_fp8_static + training_mode: end2end + dist_mode: hf + save_format: real + do_train: true + loss_type: cakld + loss_topk: null + kd_temperature: 1.0 + kd_alpha: 0.5 + lm_loss_weight: 1.0 + kd_loss_weight: 1.0 + plugin_config: + enable_scale: true + quant_config: + use_weight_quant: true + use_activation_quant: true + use_qkv_quant: false + weight_scale_init_value: 0.1 + activation_scale_init_value: 0.1 + learnable: + act_scale: false + weight_scale: true + kv_scale: false + norm: false + lwc: false + hf_args: + bf16: true + logging_steps: 1 + logging_first_step: true + per_device_train_batch_size: 1 + gradient_accumulation_steps: 1 + learning_rate: 1.0e-6 + lr_scheduler_type: cosine + num_train_epochs: 1 + max_steps: 3 + save_strategy: "no" + warmup_ratio: 0.0 + max_grad_norm: 1.0 + gradient_checkpointing: true + gradient_checkpointing_kwargs: + use_reentrant: true + deepspeed: configs/hunyuan/qat/fp8_static/learn_scale/ds_config_zero3.json + +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4/sharegpt_gpt4_256.jsonl + max_seq_length: 2048 + num_samples: 256 + batch_size: 1 diff --git a/scripts/qat/run_qat_for_hunyuanv3_a20b_zero3.sh b/scripts/qat/run_qat_for_hunyuanv3_a20b_zero3.sh new file mode 100755 index 00000000..e5b2d1c3 --- /dev/null +++ b/scripts/qat/run_qat_for_hunyuanv3_a20b_zero3.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +set -euo pipefail + +export PYTORCH_ALLOC_CONF="expandable_segments:True" + +NPROC=${NPROC:-8} +CONFIG=${CONFIG:-configs/hunyuan/qat/fp8_static/learn_scale/hunyuanv3_a20b_fp8_static_end2end_learn_scale_zero3.yaml} + +torchrun --nproc_per_node=${NPROC} \ + tools/run.py \ + -c "${CONFIG}" From c9ed4530ac48625fa95155a27e40a9202cb3105d Mon Sep 17 00:00:00 2001 
From: root Date: Mon, 27 Apr 2026 20:37:03 +0800 Subject: [PATCH 07/13] fix pre-commit --- angelslim/compressor/qat/modules/quantizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/angelslim/compressor/qat/modules/quantizer.py b/angelslim/compressor/qat/modules/quantizer.py index db1d203d..4eb627e4 100644 --- a/angelslim/compressor/qat/modules/quantizer.py +++ b/angelslim/compressor/qat/modules/quantizer.py @@ -641,7 +641,9 @@ def forward(self, input: torch.Tensor): # autocast) may have cast ``input`` to fp16 even though we run in # bf16. Align to the (fake-quantised) weight dtype so F.linear # stays consistent. - output = self.fwd_func(input.to(self.weight.dtype), weight.to(self.weight.dtype), self.bias) + output = self.fwd_func( + input.to(self.weight.dtype), weight.to(self.weight.dtype), self.bias + ) if self.use_qkv_quant: output = self.qkv_quantizer(output) return output From ac977ab1261a9ae839729192142faa814948b861 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 Apr 2026 11:20:13 +0800 Subject: [PATCH 08/13] update --- .../compressor/qat/plugins/distill_loss.py | 117 ++++++++++++++++++ .../qat/trainers/end2end_trainer.py | 111 ++--------------- angelslim/utils/__init__.py | 15 +-- 3 files changed, 126 insertions(+), 117 deletions(-) create mode 100644 angelslim/compressor/qat/plugins/distill_loss.py diff --git a/angelslim/compressor/qat/plugins/distill_loss.py b/angelslim/compressor/qat/plugins/distill_loss.py new file mode 100644 index 00000000..4024b3e4 --- /dev/null +++ b/angelslim/compressor/qat/plugins/distill_loss.py @@ -0,0 +1,117 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + + +class DistillLoss: + def __init__(self, loss_type="kl", loss_topk=None, kd_temperature=1.0): + self.loss_type = str(loss_type).lower() + self.loss_topk = loss_topk + self.kd_temperature = float(kd_temperature) + + @staticmethod + def _kl_per_token(log_p_src: torch.Tensor, p_tgt: torch.Tensor) -> torch.Tensor: + """Per-token KL(tgt || src) (shape [N]).""" + return F.kl_div(log_p_src, p_tgt, reduction="none").sum(dim=-1) + + def compute(self, student_logits, teacher_logits, labels): + """Return per-token KD losses computed only on valid labels. + + The returned dict always contains ``loss``, ``forward_kl`` and + ``backward_kl`` so callers can log diagnostics for every KD variant. + """ + flat_mask = (labels != -100).reshape(-1) + if flat_mask.sum() == 0: + zero = student_logits.new_zeros(()) + return {"loss": zero, "forward_kl": zero, "backward_kl": zero} + + # Flatten to [N, V] and keep only valid tokens. + s_flat = student_logits.flatten(0, -2)[flat_mask] + t_flat = teacher_logits.flatten(0, -2)[flat_mask] + valid_labels = labels.reshape(-1)[flat_mask] + + # Diagnostic KL (always computed at T=1 on valid tokens). + s_logp = F.log_softmax(s_flat, dim=-1) + t_logp = F.log_softmax(t_flat, dim=-1) + s_p = s_logp.exp() + t_p = t_logp.exp() + forward_kl = self._kl_per_token(s_logp, t_p).mean() + backward_kl = self._kl_per_token(t_logp, s_p).mean() + + # Main KD loss according to loss_type. 
+ if self.loss_type == "kl": + kd = forward_kl + elif self.loss_type == "rkl": + kd = backward_kl + elif self.loss_type == "mse": + kd = F.mse_loss(s_flat, t_flat) + elif self.loss_type == "kd": + # Legacy "kd": temperature-scaled forward KL. The outer trainer + # combines this value with the LM loss by configured weights. + temperature = max(self.kd_temperature, 1e-6) + kd = self._kl_per_token( + F.log_softmax(s_flat / temperature, dim=-1), + F.softmax(t_flat / temperature, dim=-1), + ).mean() * (temperature * temperature) + elif self.loss_type == "cakld": + # Contextual Asymmetric KL-divergence: per-token mixing of + # forward / reverse KL by teacher's confidence on the label. + per_tok_fkl = self._kl_per_token(s_logp, t_p) + per_tok_bkl = self._kl_per_token(t_logp, s_p) + conf = torch.gather(t_p, dim=-1, index=valid_labels.unsqueeze(-1)).squeeze(-1) + kd = (conf * per_tok_bkl + (1.0 - conf) * per_tok_fkl).mean() + elif self._is_reverse_topk_loss(self.loss_type): + topk = self._resolve_topk(self.loss_type, s_flat.size(-1)) + top_s, idx = s_flat.topk(topk, dim=-1, sorted=False) + top_t = t_flat.gather(-1, idx) + kd = self._kl_per_token( + F.log_softmax(top_t, dim=-1), + F.softmax(top_s, dim=-1), + ).mean() + elif self._is_forward_topk_loss(self.loss_type): + topk = self._resolve_topk(self.loss_type, t_flat.size(-1)) + top_t, idx = t_flat.topk(topk, dim=-1, sorted=False) + top_s = s_flat.gather(-1, idx) + kd = self._kl_per_token( + F.log_softmax(top_s, dim=-1), + F.softmax(top_t, dim=-1), + ).mean() + else: + raise ValueError( + f"Unsupported QAT kd loss_type: {self.loss_type}. " + "Valid: kl, rkl, mse, kd, cakld, kl_top[_K], r_kl_top[_K]." 
+ ) + + return {"loss": kd, "forward_kl": forward_kl, "backward_kl": backward_kl} + + @staticmethod + def _is_forward_topk_loss(loss_type): + return loss_type.startswith("kl_top") + + @staticmethod + def _is_reverse_topk_loss(loss_type): + return loss_type.startswith("r_kl_top") or loss_type.startswith("rkl_top") + + def _resolve_topk(self, loss_type, vocab_size): + topk = self.loss_topk + if topk is None and "_top_" in loss_type: + topk = int(loss_type.rsplit("_", 1)[-1]) + if topk is None: + topk = 1000 + topk = int(topk) + if topk <= 0: + raise ValueError(f"loss_topk must be positive, got: {topk}") + return min(topk, vocab_size) diff --git a/angelslim/compressor/qat/trainers/end2end_trainer.py b/angelslim/compressor/qat/trainers/end2end_trainer.py index 79828981..57eac608 100644 --- a/angelslim/compressor/qat/trainers/end2end_trainer.py +++ b/angelslim/compressor/qat/trainers/end2end_trainer.py @@ -13,12 +13,12 @@ # limitations under the License. import torch -import torch.nn.functional as F from datasets import load_dataset from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments from ....data.qat_dataset import QATDataset from ....utils import patch_deepspeed_duplicate_check, print_info +from ..plugins.distill_loss import DistillLoss from ..plugins.learnable_scale import set_quant_state from .trainer_factory import TrainerFactory @@ -54,6 +54,11 @@ def __init__(self, *args, loss_config=None, quant_config=None, **kwargs): self.kd_alpha = float(loss_config.get("kd_alpha", 0.5)) self.lm_loss_weight = float(loss_config.get("lm_loss_weight", 1.0)) self.kd_loss_weight = float(loss_config.get("kd_loss_weight", 0.0)) + self.distill_loss = DistillLoss( + loss_type=self.loss_type, + loss_topk=self.loss_topk, + kd_temperature=self.kd_temperature, + ) self.use_weight_quant = quant_config.get("use_weight_quant", False) self.use_activation_quant = quant_config.get("use_activation_quant", False) self.use_qkv_quant = quant_config.get("use_qkv_quant", False) @@ -63,89 
+68,6 @@ def __init__(self, *args, loss_config=None, quant_config=None, **kwargs): self._qat_metrics = {"train": defaultdict(list), "eval": defaultdict(list)} - # ------------------------------------------------------------------ - # KD loss helpers - # ------------------------------------------------------------------ - - @staticmethod - def _kl_per_token(log_p_src: torch.Tensor, p_tgt: torch.Tensor) -> torch.Tensor: - """Per-token KL(tgt || src) (shape [N]).""" - return F.kl_div(log_p_src, p_tgt, reduction="none").sum(dim=-1) - - def _compute_kd_components(self, student_logits, teacher_logits, labels): - """Return a dict of per-token KD losses computed only on valid - (label != -100) positions. Keys present depend on ``self.loss_type``. - - Always returns ``forward_kl`` and ``backward_kl`` (useful for - logging even when the main kd loss is e.g. ``mse`` or a topk - variant) — but only when kd_loss_weight > 0. - """ - flat_mask = (labels != -100).reshape(-1) - if flat_mask.sum() == 0: - zero = student_logits.new_zeros(()) - return {"loss": zero, "forward_kl": zero, "backward_kl": zero} - - # Flatten to [N, V] and keep only valid tokens. - s_flat = student_logits.flatten(0, -2)[flat_mask] - t_flat = teacher_logits.flatten(0, -2)[flat_mask] - valid_labels = labels.reshape(-1)[flat_mask] - - # Diagnostic KL (always computed at T=1 on valid tokens). - s_logp = F.log_softmax(s_flat, dim=-1) - t_logp = F.log_softmax(t_flat, dim=-1) - s_p = s_logp.exp() - t_p = t_logp.exp() - forward_kl = self._kl_per_token(s_logp, t_p).mean() - backward_kl = self._kl_per_token(t_logp, s_p).mean() - - # Main KD loss according to loss_type. - if self.loss_type == "kl": - kd = forward_kl - elif self.loss_type == "rkl": - kd = backward_kl - elif self.loss_type == "mse": - kd = F.mse_loss(s_flat, t_flat) - elif self.loss_type == "kd": - # Legacy "kd": temperature-scaled forward KL. 
Combined loss is - # (alpha*T^2)*KD + (1-alpha)*lm — but we now rely on - # lm/kd_loss_weight for the outer combination, so return just - # the scaled KL here. - T = max(self.kd_temperature, 1e-6) - kd = self._kl_per_token( - F.log_softmax(s_flat / T, dim=-1), - F.softmax(t_flat / T, dim=-1), - ).mean() * (T * T) - elif self.loss_type == "cakld": - # Contextual Asymmetric KL-divergence: per-token mixing of - # forward / reverse KL by teacher's confidence on the label. - per_tok_fkl = self._kl_per_token(s_logp, t_p) - per_tok_bkl = self._kl_per_token(t_logp, s_p) - conf = torch.gather(t_p, dim=-1, index=valid_labels.unsqueeze(-1)).squeeze(-1) # [N] - kd = (conf * per_tok_bkl + (1.0 - conf) * per_tok_fkl).mean() - elif self._is_reverse_topk_loss(self.loss_type): - topk = self._resolve_topk(self.loss_type, s_flat.size(-1)) - top_s, idx = s_flat.topk(topk, dim=-1, sorted=False) - top_t = t_flat.gather(-1, idx) - kd = self._kl_per_token( - F.log_softmax(top_t, dim=-1), - F.softmax(top_s, dim=-1), - ).mean() - elif self._is_forward_topk_loss(self.loss_type): - topk = self._resolve_topk(self.loss_type, t_flat.size(-1)) - top_t, idx = t_flat.topk(topk, dim=-1, sorted=False) - top_s = s_flat.gather(-1, idx) - kd = self._kl_per_token( - F.log_softmax(top_s, dim=-1), - F.softmax(top_t, dim=-1), - ).mean() - else: - raise ValueError( - f"Unsupported QAT kd loss_type: {self.loss_type}. " - "Valid: kl, rkl, mse, kd, cakld, kl_top[_K], r_kl_top[_K]." 
- ) - - return {"loss": kd, "forward_kl": forward_kl, "backward_kl": backward_kl} - def _record(self, name, value): if value is None: return @@ -195,7 +117,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if labels is None: raise ValueError("kd_loss_weight > 0 requires ``labels`` in the batch.") teacher_logits = self.get_ori_outputs(model, inputs).logits - kd_info = self._compute_kd_components( + kd_info = self.distill_loss.compute( outputs.logits, teacher_logits, labels, @@ -223,25 +145,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N return (total, outputs) if return_outputs else total - @staticmethod - def _is_forward_topk_loss(loss_type): - return loss_type.startswith("kl_top") - - @staticmethod - def _is_reverse_topk_loss(loss_type): - return loss_type.startswith("r_kl_top") or loss_type.startswith("rkl_top") - - def _resolve_topk(self, loss_type, vocab_size): - topk = self.loss_topk - if topk is None and "_top_" in loss_type: - topk = int(loss_type.rsplit("_", 1)[-1]) - if topk is None: - topk = 1000 - topk = int(topk) - if topk <= 0: - raise ValueError(f"loss_topk must be positive, got: {topk}") - return min(topk, vocab_size) - @torch.no_grad() def get_ori_outputs(self, model, inputs): teacher_inputs = dict(inputs) diff --git a/angelslim/utils/__init__.py b/angelslim/utils/__init__.py index 78d2771e..f76ef2f9 100644 --- a/angelslim/utils/__init__.py +++ b/angelslim/utils/__init__.py @@ -30,16 +30,5 @@ from .utils import print_with_rank # noqa: F401 from .utils import rank0_print # noqa: F401 from .utils import set_op_by_name # noqa: F401 -from .zero3_io import LinearizedMoeExperts # noqa: F401 -from .zero3_io import consolidated_state_dict # noqa: F401 -from .zero3_io import gathered_param_if_zero3 # noqa: F401 -from .zero3_io import gathered_params_if_zero3 # noqa: F401 -from .zero3_io import is_deepspeed_zero3_enabled # noqa: F401 -from .zero3_io import is_zero3_param # noqa: F401 
-from .zero3_io import linearize_moe_experts_empty # noqa: F401 -from .zero3_io import model_has_zero3_params # noqa: F401 -from .zero3_io import patch_deepspeed_duplicate_check # noqa: F401 -from .zero3_io import save_via_model_save_func # noqa: F401 -from .zero3_io import stream_load_scales # noqa: F401 -from .zero3_io import stream_load_weights # noqa: F401 -from .zero3_io import zero3_empty_model_from_pretrained # noqa: F401 +from .zero3_io import * # noqa: F401 + From 2fd29002ef8f6aa07c62f3638d926f6419c59361 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 Apr 2026 11:36:23 +0800 Subject: [PATCH 09/13] update sft dataset configs --- angelslim/data/dataloader.py | 2 ++ angelslim/data/text_dataset.py | 5 ++++- angelslim/engine.py | 2 ++ angelslim/utils/config_parser.py | 1 + .../qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml | 1 + tools/run.py | 2 ++ 6 files changed, 12 insertions(+), 1 deletion(-) diff --git a/angelslim/data/dataloader.py b/angelslim/data/dataloader.py index a3dcc0a6..72488ce1 100644 --- a/angelslim/data/dataloader.py +++ b/angelslim/data/dataloader.py @@ -44,6 +44,7 @@ def create_data_loader( use_audio_in_video: bool = False, model_name: str = None, quantization_config: str = None, + is_sft_data: bool = False, ) -> DataLoader: """ Create appropriate DataLoader based on data source @@ -85,6 +86,7 @@ def create_data_loader( max_length=max_length, data_path=data_source, num_samples=num_samples, + is_sft_data=is_sft_data, ) elif data_type == "MultiModalDataset": dataset = MultiModalDataset( diff --git a/angelslim/data/text_dataset.py b/angelslim/data/text_dataset.py index 67da87fd..96e5a874 100644 --- a/angelslim/data/text_dataset.py +++ b/angelslim/data/text_dataset.py @@ -34,8 +34,10 @@ def __init__( device: str = "cpu", max_length: int = 4096, num_samples: int = -1, + is_sft_data: bool = False, ): super().__init__(processor, device, max_length) + self.is_sft_data = is_sft_data self._load_data(data_path, num_samples) def _load_data(self, 
data_path: str, num_samples: int): @@ -195,7 +197,8 @@ def _load_jsonl_data(self, data_path: str, num_samples: int): input_ids = model_inputs["input_ids"] attention_mask = model_inputs["attention_mask"] labels = input_ids.clone() - labels[:, :prompt_len] = -100 + if self.is_sft_data: + labels[:, :prompt_len] = -100 # Also mask padding tokens. labels[attention_mask == 0] = -100 diff --git a/angelslim/engine.py b/angelslim/engine.py index 7feee90f..8c5344c2 100644 --- a/angelslim/engine.py +++ b/angelslim/engine.py @@ -159,6 +159,7 @@ def prepare_data( use_audio_in_video=False, model_name=None, quantization_config=None, + is_sft_data=False, ) -> Optional[Any]: """Prepare compression dataset""" if custom_dataloader is not None: @@ -185,6 +186,7 @@ def prepare_data( use_audio_in_video=use_audio_in_video, model_name=model_name, quantization_config=quantization_config, + is_sft_data=is_sft_data, ) self.max_seq_length = max_length diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py index c7aad100..31b5acdc 100644 --- a/angelslim/utils/config_parser.py +++ b/angelslim/utils/config_parser.py @@ -185,6 +185,7 @@ class DatasetConfig: batch_size: int = field(default=1) shuffle: bool = field(default=False) inference_settings: Optional[Dict[str, Any]] = field(default=None) + is_sft_data: bool = field(default=False) @dataclass diff --git a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml index 27109b26..bb0822b2 100644 --- a/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml +++ b/configs/qwen3/qat/fp8_static/learn_scale/qwen3-4b_fp8_static_end2end_learn_scale_zero3.yaml @@ -70,3 +70,4 @@ dataset: max_seq_length: 512 num_samples: 16 batch_size: 1 + is_sft_data: true diff --git a/tools/run.py b/tools/run.py index c15924c2..88d84451 100644 --- a/tools/run.py +++ b/tools/run.py 
@@ -112,6 +112,7 @@ def multi_nodes_run(config): shuffle=dataset_config.shuffle, inference_settings=dataset_config.inference_settings, use_audio_in_video=model_config.use_audio_in_video, + is_sft_data=dataset_config.is_sft_data, ) # Step 6: Initialize compressor @@ -362,6 +363,7 @@ def run(config): use_audio_in_video=model_config.use_audio_in_video, model_name=model_config.name, quantization_config=compress_config.quantization, + is_sft_data=dataset_config.is_sft_data, ) # Step 5: Initialize compressor From 4f4fd9c2f842324f8f259b27e8fbe0f69394124b Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 Apr 2026 13:38:42 +0800 Subject: [PATCH 10/13] fix pre-commit --- .../sparsity/stem/modules/forward.py.isorted | 225 ++++++++++++++++++ angelslim/utils/__init__.py | 1 - 2 files changed, 225 insertions(+), 1 deletion(-) create mode 100644 angelslim/compressor/sparsity/stem/modules/forward.py.isorted diff --git a/angelslim/compressor/sparsity/stem/modules/forward.py.isorted b/angelslim/compressor/sparsity/stem/modules/forward.py.isorted new file mode 100644 index 00000000..061ce8e3 --- /dev/null +++ b/angelslim/compressor/sparsity/stem/modules/forward.py.isorted @@ -0,0 +1,225 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Stem-patched attention forward pass. + +This module provides the replacement ``forward`` method that is bound to each +attention layer by :func:`stem.patch.stem_patch`. 
During **prefill** +(``q_len > 1``) it delegates to the Stem sparse backend; during **decode** +(``q_len == 1``) it falls back to the model's original attention implementation +(eager, FlashAttention-2, SDPA, etc.). + +The code mirrors the structure of +``transformers.models.qwen3.modeling_qwen3.Qwen3Attention.forward`` +(Transformers >= 5.2) and should be kept in sync with upstream changes. +""" + +from __future__ import annotations + +from collections.abc import Callable + +import torch +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.processing_utils import Unpack + +from ..backends import stem_forward + +# --------------------------------------------------------------------------- +# Helper functions (identical to upstream Qwen3) +# --------------------------------------------------------------------------- + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate the last dimension by splitting and concatenating halves.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply Rotary Position Embedding (RoPE) to query and key tensors.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """Repeat KV heads to match the number of query heads (GQA support). 
+ + ``(B, num_kv_heads, L, D)`` -> ``(B, num_attention_heads, L, D)`` + """ + if n_rep == 1: + return hidden_states + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# --------------------------------------------------------------------------- +# Fallback eager attention (used in decode phase, mirrors upstream) +# --------------------------------------------------------------------------- + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """Eager (non-sparse) scaled dot-product attention. + + Used as the **decode** fallback when ``q_len == 1`` and no specialised + attention implementation (e.g. FlashAttention-2) is configured. + Matches the upstream ``eager_attention_forward`` in Transformers >= 5.2. 
+ """ + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +def _assert_no_padding_mask_for_stem(attention_mask: torch.Tensor, k_len: int) -> None: + """Verify that the attention mask has no padding (required by Stem prefill). + + Raises + ------ + ValueError + If the mask is not 4-D or if the last query row contains ``-inf`` + entries (indicating padding tokens). 
+ """ + if attention_mask.ndim != 4: + raise ValueError(f"attention_mask must be 4-D, got shape={tuple(attention_mask.shape)}") + last_row = attention_mask[:, :, -1, :k_len] + if not torch.isfinite(last_row).all(): + raise ValueError("Stem prefill requires no padding mask (last query row has -inf).") + + +# --------------------------------------------------------------------------- +# Patched attention forward +# --------------------------------------------------------------------------- + + +def attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Stem-patched attention forward — drop-in replacement for + ``Qwen3Attention.forward`` (Transformers >= 5.2). + + * **Prefill** (``q_len > 1``): delegates to :func:`stem_forward` which + computes block-sparse attention according to the configured backend. + * **Decode** (``q_len == 1``): uses the model's original attention + implementation (eager / FlashAttention-2 / SDPA / flex). 
+ """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + # --- QKV projection & RoPE (identical to upstream) -------------------- + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # --- KV cache update (Transformers >= 5.2 style) ---------------------- + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + q_len = query_states.shape[2] + k_len = key_states.shape[2] + + # --- Prefill (Stem sparse attention) ---------------------------------- + if q_len > 1: + if attention_mask is not None: + _assert_no_padding_mask_for_stem(attention_mask, k_len) + + prefill_kwargs = { + "layer_idx": self.layer_idx, + "attn_forward_config": self.attn_forward_config, + } + backend = self.attn_forward_config.get("backend", "torch") + + # HPC kernels (both bf16 and fp8) handle GQA internally; + # only the pure-torch path needs explicit KV head repeat. 
+ if backend == "hpc": + stem_key_states = key_states + stem_value_states = value_states + else: + stem_key_states = repeat_kv(key_states, self.num_key_value_groups) + stem_value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = stem_forward( + query_states, stem_key_states, stem_value_states, prefill_kwargs + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_weights = None + + # --- Decode (standard attention, mirrors upstream) --------------------- + else: + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/angelslim/utils/__init__.py b/angelslim/utils/__init__.py index f76ef2f9..6df56425 100644 --- a/angelslim/utils/__init__.py +++ b/angelslim/utils/__init__.py @@ -31,4 +31,3 @@ from .utils import rank0_print # noqa: F401 from .utils import set_op_by_name # noqa: F401 from .zero3_io import * # noqa: F401 - From 2e4c45a7a22e6c43a5453e94881929d246b0d2e8 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 Apr 2026 13:40:17 +0800 Subject: [PATCH 11/13] fix pre-commit --- .../sparsity/stem/modules/forward.py.isorted | 225 ------------------ 1 file changed, 225 deletions(-) delete mode 100644 angelslim/compressor/sparsity/stem/modules/forward.py.isorted diff --git a/angelslim/compressor/sparsity/stem/modules/forward.py.isorted b/angelslim/compressor/sparsity/stem/modules/forward.py.isorted deleted file mode 100644 index 061ce8e3..00000000 --- a/angelslim/compressor/sparsity/stem/modules/forward.py.isorted +++ /dev/null @@ 
-1,225 +0,0 @@ -# Copyright 2025 Tencent Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Stem-patched attention forward pass. - -This module provides the replacement ``forward`` method that is bound to each -attention layer by :func:`stem.patch.stem_patch`. During **prefill** -(``q_len > 1``) it delegates to the Stem sparse backend; during **decode** -(``q_len == 1``) it falls back to the model's original attention implementation -(eager, FlashAttention-2, SDPA, etc.). - -The code mirrors the structure of -``transformers.models.qwen3.modeling_qwen3.Qwen3Attention.forward`` -(Transformers >= 5.2) and should be kept in sync with upstream changes. 
-""" - -from __future__ import annotations - -from collections.abc import Callable - -import torch -from torch import nn -from transformers.cache_utils import Cache -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from transformers.processing_utils import Unpack - -from ..backends import stem_forward - -# --------------------------------------------------------------------------- -# Helper functions (identical to upstream Qwen3) -# --------------------------------------------------------------------------- - - -def rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotate the last dimension by splitting and concatenating halves.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - unsqueeze_dim: int = 1, -) -> tuple[torch.Tensor, torch.Tensor]: - """Apply Rotary Position Embedding (RoPE) to query and key tensors.""" - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """Repeat KV heads to match the number of query heads (GQA support). 
- - ``(B, num_kv_heads, L, D)`` -> ``(B, num_attention_heads, L, D)`` - """ - if n_rep == 1: - return hidden_states - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -# --------------------------------------------------------------------------- -# Fallback eager attention (used in decode phase, mirrors upstream) -# --------------------------------------------------------------------------- - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float, - dropout: float = 0.0, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor]: - """Eager (non-sparse) scaled dot-product attention. - - Used as the **decode** fallback when ``q_len == 1`` and no specialised - attention implementation (e.g. FlashAttention-2) is configured. - Matches the upstream ``eager_attention_forward`` in Transformers >= 5.2. 
- """ - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -# --------------------------------------------------------------------------- -# Validation -# --------------------------------------------------------------------------- - - -def _assert_no_padding_mask_for_stem(attention_mask: torch.Tensor, k_len: int) -> None: - """Verify that the attention mask has no padding (required by Stem prefill). - - Raises - ------ - ValueError - If the mask is not 4-D or if the last query row contains ``-inf`` - entries (indicating padding tokens). 
- """ - if attention_mask.ndim != 4: - raise ValueError(f"attention_mask must be 4-D, got shape={tuple(attention_mask.shape)}") - last_row = attention_mask[:, :, -1, :k_len] - if not torch.isfinite(last_row).all(): - raise ValueError("Stem prefill requires no padding mask (last query row has -inf).") - - -# --------------------------------------------------------------------------- -# Patched attention forward -# --------------------------------------------------------------------------- - - -def attn_forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None, - past_key_values: Cache | None = None, - **kwargs: Unpack[FlashAttentionKwargs], -) -> tuple[torch.Tensor, torch.Tensor | None]: - """Stem-patched attention forward — drop-in replacement for - ``Qwen3Attention.forward`` (Transformers >= 5.2). - - * **Prefill** (``q_len > 1``): delegates to :func:`stem_forward` which - computes block-sparse attention according to the configured backend. - * **Decode** (``q_len == 1``): uses the model's original attention - implementation (eager / FlashAttention-2 / SDPA / flex). 
- """ - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - # --- QKV projection & RoPE (identical to upstream) -------------------- - query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - # --- KV cache update (Transformers >= 5.2 style) ---------------------- - if past_key_values is not None: - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) - - q_len = query_states.shape[2] - k_len = key_states.shape[2] - - # --- Prefill (Stem sparse attention) ---------------------------------- - if q_len > 1: - if attention_mask is not None: - _assert_no_padding_mask_for_stem(attention_mask, k_len) - - prefill_kwargs = { - "layer_idx": self.layer_idx, - "attn_forward_config": self.attn_forward_config, - } - backend = self.attn_forward_config.get("backend", "torch") - - # HPC kernels (both bf16 and fp8) handle GQA internally; - # only the pure-torch path needs explicit KV head repeat. 
- if backend == "hpc": - stem_key_states = key_states - stem_value_states = value_states - else: - stem_key_states = repeat_kv(key_states, self.num_key_value_groups) - stem_value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_output = stem_forward( - query_states, stem_key_states, stem_value_states, prefill_kwargs - ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_weights = None - - # --- Decode (standard attention, mirrors upstream) --------------------- - else: - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( - self.config._attn_implementation, eager_attention_forward - ) - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=self.sliding_window, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights From 3461ff6ab491c23776bb557c43d6d8cb35b3c4d3 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 Apr 2026 13:42:50 +0800 Subject: [PATCH 12/13] fix pre-commit --- angelslim/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/angelslim/utils/__init__.py b/angelslim/utils/__init__.py index 6df56425..ff749ce8 100644 --- a/angelslim/utils/__init__.py +++ b/angelslim/utils/__init__.py @@ -30,4 +30,4 @@ from .utils import print_with_rank # noqa: F401 from .utils import rank0_print # noqa: F401 from .utils import set_op_by_name # noqa: F401 -from .zero3_io import * # noqa: F401 +from .zero3_io import * # noqa: F401 403 From e1b70964fe33fc3bfb93db2d0fdea6476103048f Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 Apr 2026 13:44:03 +0800 Subject: [PATCH 13/13] fix pre-commit --- angelslim/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/angelslim/utils/__init__.py 
b/angelslim/utils/__init__.py index ff749ce8..a12f24b5 100644 --- a/angelslim/utils/__init__.py +++ b/angelslim/utils/__init__.py @@ -30,4 +30,4 @@ from .utils import print_with_rank # noqa: F401 from .utils import rank0_print # noqa: F401 from .utils import set_op_by_name # noqa: F401 -from .zero3_io import * # noqa: F401 403 +from .zero3_io import * # noqa: F401 F403