Support QAT with deepspeed zero3 and gradient checkpoint #290

Merged
yghstill merged 13 commits into Tencent:main from yisunlp:deepspeed3 on Apr 29, 2026

Conversation

@yisunlp (Collaborator) commented Apr 27, 2026

This PR adds DeepSpeed ZeRO-3 support to the QAT pipeline so that QAT can scale to models that no longer fit on a single GPU (e.g. Qwen3-30B-A3B). It introduces a streaming model loader, ZeRO-3-aware fake quantization, MoE expert linearisation under deepspeed.zero.Init, and a consolidated rank-0 save path. Single-GPU behavior is unchanged.

Minor changes

  • add token-level cakld loss
  • add sft loss mask for text dataset

Changes for deepspeed zero3

angelslim/utils/zero3_io.py (new, all ZeRO-3 helpers in one file)

  • ZeRO-3 detection + parameter-gather contexts (is_deepspeed_zero3_enabled, gathered_param(s)_if_zero3, ...) — usage sketch after this list
  • LinearizedMoeExperts + linearize_moe_experts_empty: duck-typed replacement for any HF "fused MoE expert" module, built inside deepspeed.zero.Init so per-expert nn.Linears are partitioned immediately with no copy from the fused tensor
  • zero3_empty_model_from_pretrained + stream_load_weights: build an empty sharded model, then rank-0 reads safetensors shards one at a time and broadcasts each tensor into its possibly sharded target via GatheredParameters(modifier_rank=0). Fused MoE keys (*.experts.gate_up_proj/down_proj) are sliced per expert into the linearised targets
  • stream_load_scales: load weight_scale / input_scale / k_cache.scale / v_cache.scale from a previous PTQ "real" checkpoint into the matching QuantLinear quantizers
  • save_via_model_save_func: gather sharded params into a rank-0 CPU state_dict and patch model.state_dict() so the existing save_func.save(...) writes a single consolidated checkpoint
  • patch_deepspeed_duplicate_check: no-ops DeepSpeed's _check_for_duplicates so QAT scale parameters that share storage across module views are accepted
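
The gather context is the primitive the rest of the ZeRO-3 path builds on. A minimal sketch of how gathered_param_if_zero3 could look, assuming the stock transformers/deepspeed helpers; the actual implementation in zero3_io.py may differ in detail:

```python
import contextlib
import torch

try:
    import deepspeed
    # re-exported from transformers.integrations.deepspeed in recent versions
    from transformers.integrations import is_deepspeed_zero3_enabled
    _HAS_ZERO3_DEPS = True
except ImportError:
    _HAS_ZERO3_DEPS = False


@contextlib.contextmanager
def gathered_param_if_zero3(param: torch.nn.Parameter, modifier_rank=None):
    """Temporarily gather a possibly ZeRO-3-sharded parameter.

    Every rank must enter this context, or the underlying all-gather
    collective deadlocks. With modifier_rank set, that rank's in-place
    edits are re-scattered to the shards on exit, which is the mechanism
    a streaming loader like stream_load_weights can rely on.
    """
    if _HAS_ZERO3_DEPS and is_deepspeed_zero3_enabled():
        with deepspeed.zero.GatheredParameters([param], modifier_rank=modifier_rank):
            yield
    else:
        yield  # single GPU / ZeRO-1/2: parameter is already whole
```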

Quantizer (angelslim/compressor/qat/modules/quantizer.py)

  • Accept explicit weight_shape so scale Parameters can be sized without inspecting the possibly sharded weight
  • New weight_scale_init_value / activation_scale_init_value (default 1.0): used under ZeRO-3 to allocate scales from shape + init value, skipping any data-dependent initialization
  • QuantLinear.forward: align input.dtype to weight.dtype and disable autocast around F.linear to avoid the dtype mismatch DeepSpeed's zero3_linear_wrap autocast can introduce
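
A hedged sketch of that forward fix; the quantizer machinery the real QuantLinear owns is elided, so only the dtype alignment and the autocast guard are shown:

```python
import torch
import torch.nn.functional as F


class QuantLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # DeepSpeed's zero3_linear_wrap can run under autocast and hand us
        # an input whose dtype no longer matches the weight; align first.
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)
        # ...the real module fake-quantizes weight and input here...
        with torch.autocast(device_type=input.device.type, enabled=False):
            return F.linear(input, self.weight, self.bias)
```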

Plugin (angelslim/compressor/qat/plugins/learnable_scale.py)

  • New from_ptq_ckpt_dir (required under ZeRO-3, since lazy-init via forward is impossible on sharded weights). After replacing nn.Linear with QuantLinear, calls stream_load_scales to fill scales from PTQ
  • quant_inplace: gathers weight + weight-quantizer parameters before the in-place fake quant
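
Roughly, quant_inplace under ZeRO-3 might look like the following; treating qlinear.weight_quantizer as a callable that returns the fake-quantized weight is an assumption of this sketch:

```python
import torch.distributed as dist
import deepspeed


def quant_inplace(qlinear) -> None:
    # Gather the sharded weight together with the learnable scale parameters;
    # modifier_rank=0 re-scatters rank 0's in-place edit on context exit.
    params = [qlinear.weight] + list(qlinear.weight_quantizer.parameters())
    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
        if not dist.is_initialized() or dist.get_rank() == 0:
            fq_weight = qlinear.weight_quantizer(qlinear.weight)  # fake quant-dequant
            qlinear.weight.data.copy_(fq_weight)
```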

QAT orchestrator (angelslim/compressor/qat/qat.py)

  • ZeRO-3 convert(): keeps model structure unchanged on every rank so collective gathers stay symmetric. All ranks materialise each layer once via gathered_param_if_zero3, then only rank 0 builds a temporary QDQModule, runs the fp8/int kernel, and copies the result into a _rank0_state_dict. Per-rank CPU peak is bounded by one layer's weights, not world_size × model_size (see the sketch after this list)
  • save() hands the prebuilt dict to save_via_model_save_func, which patches state_dict so the existing model-specific save_func.save(...) writes a single consolidated checkpoint on rank 0
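
An illustrative version of that convert loop, reusing the gathered_param_if_zero3 sketch from above; build_qdq_module and quant_dequant stand in for the real QDQModule construction and kernel:

```python
import torch
import torch.distributed as dist

# gathered_param_if_zero3 is the context manager sketched earlier

def convert_zero3(model: torch.nn.Module, build_qdq_module):
    rank0_state_dict = {}
    is_rank0 = not dist.is_initialized() or dist.get_rank() == 0
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        # Every rank enters the gather so the collectives stay symmetric,
        # even though only rank 0 does any work with the full weight.
        with gathered_param_if_zero3(module.weight):
            if is_rank0:
                qdq = build_qdq_module(module)  # temporary QDQModule (placeholder)
                rank0_state_dict[f"{name}.weight"] = (
                    qdq.quant_dequant(module.weight.detach()).cpu()
                )
        # The gathered weight is freed here, so per-rank peak memory is
        # bounded by a single layer, not world_size x model_size.
    return rank0_state_dict
```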

Trainer (angelslim/compressor/qat/trainers/end2end_trainer.py)

  • compute_loss rewritten to combine lm_loss_weight * lm_loss + kd_loss_weight * kd_loss(loss_type); either weight may be 0
  • KD loss types: kl / rkl / mse / kd / kl_top_K / r_kl_top_K, plus a new cakld (conf * reverse_kl + (1 - conf) * forward_kl, where conf is the teacher's probability on the gold token). All KD losses operate only on positions where labels != -100 (see the cakld sketch after this list)
  • Optimizer collection uses id-based deduplication so tied scale Parameters across MoE expert views don't crash DeepSpeed's duplicate check
  • New log() override injects per-step components (lm_loss / kd/<type> / kd/forward_kl / kd/backward_kl / total_loss) into the HF Trainer log dict
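
A self-contained sketch of the cakld loss as described above; the function name and signature are illustrative, not the trainer's actual API:

```python
import torch
import torch.nn.functional as F


def cakld_loss(student_logits: torch.Tensor,
               teacher_logits: torch.Tensor,
               labels: torch.Tensor,
               ignore_index: int = -100) -> torch.Tensor:
    # Positions carrying ignore_index contribute nothing to the loss.
    mask = labels != ignore_index                             # [B, T]
    log_p_s = F.log_softmax(student_logits.float(), dim=-1)   # [B, T, V]
    log_p_t = F.log_softmax(teacher_logits.float(), dim=-1)
    p_s, p_t = log_p_s.exp(), log_p_t.exp()

    # forward KL = KL(teacher || student), reverse KL = KL(student || teacher)
    forward_kl = (p_t * (log_p_t - log_p_s)).sum(-1)          # [B, T]
    reverse_kl = (p_s * (log_p_s - log_p_t)).sum(-1)

    # conf: the teacher's probability on the gold token
    gold = labels.clamp_min(0).unsqueeze(-1)                  # avoid indexing with -100
    conf = p_t.gather(-1, gold).squeeze(-1)                   # [B, T]

    per_token = conf * reverse_kl + (1.0 - conf) * forward_kl
    return (per_token * mask).sum() / mask.sum().clamp_min(1)
```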

Other small changes

  • models/base_model.py: from_pretrained enters the ZeRO-3 path when is_deepspeed_zero3_enabled(), otherwise falls back to the original HF path unchanged
  • data/text_dataset.py: jsonl SFT samples are supervised only on the last assistant turn; everything before becomes -100. Fix a previous double-shift bug: HF CausalLM does the next-token shift internally, so labels = input_ids.clone()
  • utils/utils.py: set_op_by_name falls back to getattr(mod_, str(idx)) for containers that register children via setattr(self, str(idx), ...), which is how the linearised MoE container registers its experts (sketch after this list)
  • utils/config_parser.py: new QATTrainingConfig fields from_ptq_ckpt, lm_loss_weight (1.0), kd_loss_weight (0.0)
  • engine.py: normalise device_map string "None" / "distributed" → Python None
  • tools/run.py: pre-construct Seq2SeqTrainingArguments(deepspeed=...) before prepare_model() so HF's HfTrainerDeepSpeedConfig weak-ref is registered, which makes is_deepspeed_zero3_enabled() return True during from_pretrained
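
A sketch of the set_op_by_name fallback; the try/except shape is assumed, only the getattr(mod_, str(idx)) fallback itself comes from the PR:

```python
import torch.nn as nn


def set_op_by_name(model: nn.Module, name: str, new_module: nn.Module) -> None:
    # Walk the dotted module path and replace the leaf with new_module.
    *path, leaf = name.split(".")
    mod_ = model
    for part in path:
        if part.isdigit():
            try:
                mod_ = mod_[int(part)]  # nn.ModuleList / nn.Sequential
            except TypeError:
                # Containers that register children via setattr(self, str(idx), ...)
                # (e.g. the linearised MoE container) are not subscriptable.
                mod_ = getattr(mod_, part)
        else:
            mod_ = getattr(mod_, part)
    setattr(mod_, leaf, new_module)
```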

Configs / scripts / docs

  • configs/qwen3/qat/fp8_static/learn_scale/{ds_config_zero3.json, qwen3-{4b,30b-a3b}_..._zero3.yaml}
  • scripts/qat/run_qat_for_qwen_{4b,30b_a3b}_zero3.sh
  • docs/source/features/quantization/qat_zero3.md — architecture, execution flow, config reference, limitations

# ------------------------------------------------------------------

@staticmethod
def _kl_per_token(log_p_src: torch.Tensor, p_tgt: torch.Tensor) -> torch.Tensor:
Collaborator:

Could the distillation-related code go into its own .py file, perhaps under plugins?

Collaborator Author:

Done.

labels = inputs["input_ids"].roll(shifts=-1, dims=-1)
labels[:, -1] = -100
inputs["labels"] = labels.to(self.device)
# HF CausalLM models shift labels internally; feed labels == input_ids.
Collaborator:

Does this change affect PTQ?

Collaborator Author:

The label shift was a pre-existing bug: CausalLM models shift labels internally, so there is no need to shift them ahead of time.
Also added an is_sft_data switch: when enabled, loss is computed only on the tokens after the assistant turn; it defaults to off.

Comment thread on angelslim/utils/__init__.py (outdated)
from .utils import print_with_rank # noqa: F401
from .utils import rank0_print # noqa: F401
from .utils import set_op_by_name # noqa: F401
from .zero3_io import LinearizedMoeExperts # noqa: F401
Collaborator:

Could all of the `from .zero3_io import` lines be consolidated into one?

Collaborator Author:

Done.

@yghstill merged commit 7e4ffbc into Tencent:main on Apr 29, 2026
5 checks passed