Support QAT with deepspeed zero3 and gradient checkpoint #290
Merged
yghstill merged 13 commits into Tencent:main on Apr 29, 2026
Conversation
yghstill reviewed on Apr 28, 2026
```python
# ------------------------------------------------------------------
@staticmethod
def _kl_per_token(log_p_src: torch.Tensor, p_tgt: torch.Tensor) -> torch.Tensor:
```
yghstill (Collaborator) commented:

> Could the distillation-related code live in its own .py file, perhaps under plugins?
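For reference, a minimal sketch of what a per-token KL helper with the signature above typically computes; the body here is an assumption for illustration, since the PR's implementation is not shown in this excerpt:

```python
import torch

def kl_per_token(log_p_src: torch.Tensor, p_tgt: torch.Tensor) -> torch.Tensor:
    # KL(p_tgt || p_src) reduced over the vocabulary dimension; log_p_src
    # holds source log-probabilities, p_tgt holds target probabilities.
    return (p_tgt * (p_tgt.clamp_min(1e-9).log() - log_p_src)).sum(dim=-1)
```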
```python
labels = inputs["input_ids"].roll(shifts=-1, dims=-1)
labels[:, -1] = -100
inputs["labels"] = labels.to(self.device)
# HF CausalLM models shift labels internally; feed labels == input_ids.
```
Author (Collaborator) replied:

> The label shift was a pre-existing bug: CausalLM models already shift labels internally, so there is no need to shift them ahead of time. A new `is_sft_data` switch has also been added; when enabled, the loss is computed only on tokens after the assistant turn. It is off by default.
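A hedged sketch of the corrected labeling described in the reply; `build_labels` and `assistant_start` are hypothetical names used only for illustration:

```python
import torch

def build_labels(input_ids: torch.Tensor, is_sft_data: bool,
                 assistant_start: int) -> torch.Tensor:
    # HF CausalLM models do the next-token shift internally, so labels start
    # as a plain copy of input_ids (no pre-shift with roll()).
    labels = input_ids.clone()
    if is_sft_data:
        # assistant_start is an assumed precomputed index of where the
        # assistant response begins; earlier tokens are masked out.
        labels[:, :assistant_start] = -100
    return labels
```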
```python
from .utils import print_with_rank  # noqa: F401
from .utils import rank0_print  # noqa: F401
from .utils import set_op_by_name  # noqa: F401
from .zero3_io import LinearizedMoeExperts  # noqa: F401
```
yghstill (Collaborator) commented:

> Could all of the `from .zero3_io import` statements be consolidated into one?
yghstill approved these changes on Apr 29, 2026
This PR adds DeepSpeed ZeRO-3 support to the QAT pipeline so that QAT can scale to models that no longer fit on a single GPU (e.g. Qwen3-30B-A3B). It introduces a streaming model loader, ZeRO-3-aware fake quantization, MoE expert linearisation under `deepspeed.zero.Init`, and a consolidated rank-0 save path. Single-GPU behavior is unchanged.

## Changes for DeepSpeed ZeRO-3
- `angelslim/utils/zero3_io.py` (new; all ZeRO-3 helpers in one file):
  - ZeRO-3 state helpers (`is_deepspeed_zero3_enabled`, `gathered_param(s)_if_zero3`, ...).
  - `LinearizedMoeExperts` + `linearize_moe_experts_empty`: a duck-typed replacement for any HF "fused MoE expert" module, built inside `deepspeed.zero.Init` so the per-expert `nn.Linear`s are partitioned immediately, with no copy from the fused tensor.
  - `zero3_empty_model_from_pretrained` + `stream_load_weights`: build an empty sharded model, then rank 0 reads the safetensors shards one at a time and broadcasts each tensor into its possibly sharded target via `GatheredParameters(modifier_rank=0)`; fused MoE keys (`*.experts.gate_up_proj/down_proj`) are sliced per expert into the linearised targets (see the sketch after this list).
  - `stream_load_scales`: loads `weight_scale`/`input_scale`/`k_cache.scale`/`v_cache.scale` from a previous PTQ "real" checkpoint into the matching `QuantLinear` quantizers.
  - `save_via_model_save_func`: gathers sharded params into a rank-0 CPU `state_dict` and patches `model.state_dict()` so the existing `save_func.save(...)` writes a single consolidated checkpoint.
  - `patch_deepspeed_duplicate_check`: no-ops DeepSpeed's `_check_for_duplicates` so QAT scale parameters that share storage across module views are accepted.
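A minimal sketch of the streaming-load pattern described above; the function name mirrors the PR's helper but the body is an assumption, and the fused-MoE slicing path is elided:

```python
# Hedged sketch of the stream-load pattern, not the PR's exact code: rank 0
# reads each safetensors shard; GatheredParameters(modifier_rank=0) broadcasts
# rank 0's writes into the (possibly partitioned) targets on context exit.
import deepspeed
import torch.distributed as dist
from safetensors import safe_open

def stream_load_weights_sketch(model, shard_paths):
    rank = dist.get_rank() if dist.is_initialized() else 0
    named = dict(model.named_parameters())
    for path in shard_paths:
        # All ranks read the (cheap) key list so the gather order is symmetric.
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                param = named.get(key)
                if param is None:
                    continue  # e.g. fused MoE keys, sliced per expert elsewhere
                with deepspeed.zero.GatheredParameters([param], modifier_rank=0):
                    if rank == 0:
                        param.data.copy_(f.get_tensor(key).to(param.dtype))
```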
- **Quantizer** (`angelslim/compressor/qat/modules/quantizer.py`):
  - New `weight_shape` argument, so scale `Parameter`s can be sized without inspecting the possibly sharded weight.
  - New `weight_scale_init_value`/`activation_scale_init_value` (default `1.0`): used under ZeRO-3 to allocate scales from shape + init value, skipping any data-dependent initialization.
  - `QuantLinear.forward`: aligns `input.dtype` with `weight.dtype` and disables autocast around `F.linear`, avoiding the dtype mismatch that DeepSpeed's `zero3_linear_wrap` autocast can introduce (see the sketch after this list).
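A minimal sketch of the `QuantLinear.forward` fix, with the fake-quant steps elided; the class body here is an assumption, not the PR's exact code:

```python
# Hedged sketch: match the input dtype to the weight and run F.linear with
# autocast disabled, so no implicit re-cast can reintroduce a dtype mismatch.
import torch
import torch.nn.functional as F

class QuantLinearSketch(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # ... fake-quantize weight/input with learnable scales here ...
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)
        # Disable autocast locally so F.linear sees exactly these dtypes.
        with torch.autocast(device_type=input.device.type, enabled=False):
            return F.linear(input, self.weight, self.bias)
```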
- **Plugin** (`angelslim/compressor/qat/plugins/learnable_scale.py`):
  - New `from_ptq_ckpt_dir` path (required under ZeRO-3, since lazy init via a forward pass is impossible on sharded weights): after replacing `nn.Linear` with `QuantLinear`, it calls `stream_load_scales` to fill the scales from PTQ.
  - `quant_inplace`: gathers the weight and the weight-quantizer parameters before the in-place fake quant (see the sketch after this list).
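A hedged sketch of the `quant_inplace` gathering pattern; `weight_quantizer` and its call signature are assumed attribute names, not the PR's exact code:

```python
import deepspeed
import torch.distributed as dist

def quant_inplace_sketch(qlinear):
    rank = dist.get_rank() if dist.is_initialized() else 0
    params = [qlinear.weight] + list(qlinear.weight_quantizer.parameters())
    # Gather the sharded weight plus the quantizer's learnable scales; rank 0's
    # in-place modification is broadcast back and re-partitioned on exit.
    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
        if rank == 0:
            qlinear.weight.data.copy_(qlinear.weight_quantizer(qlinear.weight))
```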
- **QAT orchestrator** (`angelslim/compressor/qat/qat.py`):
  - `convert()` keeps the model structure unchanged on every rank so collective gathers stay symmetric. All ranks materialise each layer once via `gathered_param_if_zero3`; then only rank 0 builds a temporary `QDQModule`, runs the fp8/int kernel, and copies the result into a `_rank0_state_dict`. Per-rank CPU peak is bounded by one layer's weights, not `world_size × model_size` (see the sketch after this list).
  - `save()` hands the prebuilt dict to `save_via_model_save_func`, which patches `state_dict` so the existing model-specific `save_func.save(...)` writes a single consolidated checkpoint on rank 0.
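A hedged sketch of the `convert()` gather pattern; `quantize_fn` stands in for the temporary `QDQModule` plus fp8/int kernel, and its interface is an assumption:

```python
import deepspeed
import torch.distributed as dist

def convert_layer_sketch(module, name, quantize_fn, rank0_state_dict):
    rank = dist.get_rank() if dist.is_initialized() else 0
    # Every rank enters the gather so the collective stays symmetric.
    with deepspeed.zero.GatheredParameters(list(module.parameters())):
        if rank == 0:
            qweight, scale = quantize_fn(module)   # full weights, rank 0 only
            rank0_state_dict[f"{name}.weight"] = qweight.cpu()
            rank0_state_dict[f"{name}.weight_scale"] = scale.cpu()
    # Full tensors are freed on exit, so per-rank CPU peak stays bounded by a
    # single layer's weights rather than world_size x model_size.
```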
- **Trainer** (`angelslim/compressor/qat/trainers/end2end_trainer.py`):
  - `compute_loss` rewritten to combine `lm_loss_weight * lm_loss + kd_loss_weight * kd_loss` (selected via `loss_type`); either weight may be 0.
  - KD loss types: `kl`/`rkl`/`mse`/`kd`/`kl_top_K`/`r_kl_top_K`, plus a new `cakld` (`conf * reverse_kl + (1 - conf) * forward_kl`, where `conf` is the teacher's probability on the gold token). All KD losses operate only on `labels != -100` positions (see the sketch after this list).
  - A `log()` override injects per-step components (`lm_loss`, `kd/<type>`, `kd/forward_kl`, `kd/backward_kl`, `total_loss`) into the HF Trainer log dict.
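A minimal sketch of the `cakld` blend described above (tensor shapes assumed: logits `[B, T, V]`, labels `[B, T]`); this follows the stated formula, not the trainer's exact code:

```python
import torch
import torch.nn.functional as F

def cakld_loss(student_logits, teacher_logits, labels):
    mask = labels != -100                            # supervised positions only
    s_logp = F.log_softmax(student_logits, dim=-1)
    t_logp = F.log_softmax(teacher_logits, dim=-1)
    s_p, t_p = s_logp.exp(), t_logp.exp()
    fwd_kl = (t_p * (t_logp - s_logp)).sum(-1)       # KL(teacher || student)
    rev_kl = (s_p * (s_logp - t_logp)).sum(-1)       # KL(student || teacher)
    gold = labels.clamp_min(0).unsqueeze(-1)         # -100 -> 0, masked anyway
    conf = t_p.gather(-1, gold).squeeze(-1)          # teacher prob on gold token
    loss = conf * rev_kl + (1.0 - conf) * fwd_kl
    return (loss * mask).sum() / mask.sum().clamp_min(1)
```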
## Other small changes

- `models/base_model.py`: `from_pretrained` enters the ZeRO-3 path when `is_deepspeed_zero3_enabled()`; otherwise it falls back to the original HF path unchanged.
- `data/text_dataset.py`: jsonl SFT samples are supervised only on the last assistant turn; everything before it becomes `-100`. Also fixes a previous double-shift bug: HF CausalLM does the next-token shift internally, so `labels = input_ids.clone()`.
- `utils/utils.py`: `set_op_by_name` falls back to `getattr(mod_, str(idx))` for containers that register children via `setattr(self, str(idx), ...)`, i.e. the linearised MoE container (see the sketch after this list).
- `utils/config_parser.py`: new `QATTrainingConfig` fields `from_ptq_ckpt`, `lm_loss_weight` (default `1.0`), `kd_loss_weight` (default `0.0`).
- `engine.py`: normalises the device_map strings `"None"`/`"distributed"` to Python `None`.
- `tools/run.py`: pre-constructs `Seq2SeqTrainingArguments(deepspeed=...)` before `prepare_model()` so HF's `HfTrainerDeepSpeedConfig` weak-ref is registered, which makes `is_deepspeed_zero3_enabled()` return `True` during `from_pretrained`.
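A hedged sketch of the `set_op_by_name` fallback described above, not the repository's exact implementation:

```python
import torch.nn as nn

def set_op_by_name(model: nn.Module, name: str, new_op: nn.Module) -> None:
    parts = name.split(".")
    mod = model
    for p in parts[:-1]:
        try:
            # Numeric parts normally index ModuleList/Sequential containers.
            mod = mod[int(p)] if p.isdigit() else getattr(mod, p)
        except TypeError:
            # Fallback for containers that register children via
            # setattr(self, str(idx), ...), e.g. the linearised MoE container.
            mod = getattr(mod, p)
    setattr(mod, parts[-1], new_op)
```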
## Configs / scripts / docs

- `configs/qwen3/qat/fp8_static/learn_scale/{ds_config_zero3.json, qwen3-{4b,30b-a3b}_..._zero3.yaml}`
- `scripts/qat/run_qat_for_qwen_{4b,30b_a3b}_zero3.sh`
- `docs/source/features/quantization/qat_zero3.md`: architecture, execution flow, config reference, limitations