2 changes: 2 additions & 0 deletions .gitignore
@@ -15,6 +15,8 @@ eval/
*_ckpt*/
output/
outputs/
output*/
logs*/
outs/
wandb/
tools/results/
138 changes: 122 additions & 16 deletions angelslim/compressor/qat/modules/quantizer.py
@@ -18,6 +18,8 @@
import torch.nn as nn
import torch.nn.functional as F

from ....utils import is_deepspeed_zero3_enabled, is_zero3_param

FP8_E4M3_QMIN = -448
FP8_E4M3_QMAX = 448
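> Aside: ±448 is the finite range of the float8 E4M3 format, which a recent PyTorch (2.1+, where `float8_e4m3fn` exists) can confirm:

```python
import torch

# E4M3's finite range matches the constants above.
print(torch.finfo(torch.float8_e4m3fn).max)   # 448.0
print(torch.finfo(torch.float8_e4m3fn).min)   # -448.0
```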

@@ -49,10 +51,30 @@ def _parse_bits_and_dtype(qtype_str):


class Quantizer(nn.Module):
def __init__(self, config, quant_info, x=None, is_act=False, resume=False, num_heads=-1):
def __init__(
self,
config,
quant_info,
x=None,
is_act=False,
resume=False,
num_heads=-1,
weight_shape=None,
):
super().__init__()
self.is_act = is_act
self.num_heads = num_heads
# ``weight_shape`` lets the caller pre-declare the (out_features,
# in_features) of the parent Linear so we can size weight-side
# quantizer Parameters without ever touching the (possibly ZeRO-3
# sharded) weight tensor.
self.weight_shape = (
(int(weight_shape[0]), int(weight_shape[1])) if weight_shape is not None else None
)
# Configurable initial values used when ZeRO-3 is active and we
# cannot depend on the weight data.
self.weight_scale_init_value = float(config.get("weight_scale_init_value", 1.0))
self.activation_scale_init_value = float(config.get("activation_scale_init_value", 1.0))
info = quant_info.quant_algo_info["w"]
self.group_size = quant_info.quant_algo_info.get("w_group_size", -1)
rewrite_conf = config.get("weight", {})
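> Both init values are plain config lookups; a minimal sketch of the config fragment the lines above read (only the two `*_init_value` keys are consumed here, the rest is illustrative):

```python
# Hypothetical config fragment; "weight" holds the rewrite_conf dict
# that __init__ reads next.
config = {
    "weight_scale_init_value": 1.0,       # fill value for weight scales under ZeRO-3
    "activation_scale_init_value": 1.0,   # fill value for activation scales on resume
    "weight": {},
}
```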
@@ -117,8 +139,21 @@ def _init_quant_params(self, x):
self.scale = self.zero_point = None
if self.resume:
self.init = True
zp = torch.empty(1) if not self.is_sym else None
self._set_quant_parameters(torch.empty(1), zp)
init_val = self.activation_scale_init_value
scale = torch.full((1,), init_val, dtype=torch.float32)
zp = torch.zeros(1, dtype=torch.float32) if not self.is_sym else None
self._set_quant_parameters(scale, zp)
return

# Weight-side path. If we cannot use ``x`` (ZeRO-3 sharded,
# meta, or simply not provided), allocate Parameters by shape
# and ``weight_scale_init_value``.
if self._needs_external_weight_init(x):
shape = self._weight_scale_shape_from_meta()
init_val = self.weight_scale_init_value
scale = torch.full(shape, init_val, dtype=torch.float32)
zp = torch.zeros(shape, dtype=torch.float32) if not self.is_sym else None
self._set_quant_parameters(scale, zp)
return

if self.is_sym:
@@ -131,6 +166,52 @@ def _init_quant_params(self, x):
)
self._set_quant_parameters(scale, zp.round())

def _needs_external_weight_init(self, x):
"""True when weight-side init must skip data-dependent computation
and instead allocate Parameters from shape + init_value.

Triggered by:
* DeepSpeed ZeRO-3 active (HF integration registered)
* ``x`` is a ZeRO-3 sharded Parameter
* ``x`` is None / on meta device / empty
"""
if is_deepspeed_zero3_enabled():
return True
if x is None:
return True
if is_zero3_param(x):
return True
if hasattr(x, "device") and x.device.type == "meta":
return True
if hasattr(x, "numel") and x.numel() == 0:
return True
return False
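> Two of these triggers are easy to reproduce without DeepSpeed; a small sketch of the tensor-side checks:

```python
import torch

# Stand-ins for the tensor-side conditions checked above.
meta_w = torch.nn.Parameter(torch.empty(4096, 4096, device="meta"))
empty_shard = torch.nn.Parameter(torch.empty(0))

assert meta_w.device.type == "meta"   # meta-device weight -> external init
assert empty_shard.numel() == 0       # zero-numel ZeRO-3-style shard -> external init
```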

def _weight_2d_shape(self):
"""Resolve (out_features, in_features) for the underlying Linear.
Callers must have passed ``weight_shape`` via ``QuantLinear``."""
if self.weight_shape is not None:
return self.weight_shape
raise RuntimeError(
"Quantizer needs ``weight_shape`` to size weight scale without a "
"concrete tensor (set in QuantLinear.__init__)."
)

def _weight_scale_shape_from_meta(self):
out_dim, in_dim = self._weight_2d_shape()
if self.granularity == "per-channel":
return (out_dim, 1)
if self.granularity == "per-group":
if not self.group_size or self.group_size <= 0:
raise ValueError("per-group quantization requires positive group_size.")
if in_dim % self.group_size != 0:
raise ValueError(
f"dim 1 ({in_dim}) not divisible by group_size ({self.group_size})"
)
return (out_dim, in_dim // self.group_size)
# per-tensor and any reduce-to-scalar variant
return (1,)
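> Worked shapes for an illustrative `Linear(in_features=11008, out_features=4096)`:

```python
# Scale shapes produced by _weight_scale_shape_from_meta (dims illustrative):
#   per-tensor                 -> (1,)
#   per-channel                -> (4096, 1)
#   per-group, group_size=128  -> (4096, 11008 // 128) == (4096, 86)
```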

def _init_lwc_params(self, x, config):
lwc_cfg = config.get("lwc", {})
if isinstance(lwc_cfg, dict):
@@ -141,11 +222,18 @@ def _init_lwc_params(self, x, config):
self.lwc_init_value = 4.0

if self.lwc:
if x.dim() != 2:
x_for_shape = x.flatten(1)
# Resolve (out_dim, in_dim) without depending on ``x`` data.
if self._needs_external_weight_init(x):
out_dim, in_dim = self._weight_2d_shape()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
x_for_shape = x
out_dim, in_dim = x_for_shape.shape
if x.dim() != 2:
x_for_shape = x.flatten(1)
else:
x_for_shape = x
out_dim, in_dim = x_for_shape.shape
device = x.device

if self.granularity == "per-group":
if not self.group_size or self.group_size <= 0:
raise ValueError("per-group quantization requires positive group_size.")
@@ -157,9 +245,7 @@
else:
dim1 = 1

init = (
torch.ones((dim1, 1), device=x.device, dtype=torch.float32) * self.lwc_init_value
)
init = torch.ones((dim1, 1), device=device, dtype=torch.float32) * self.lwc_init_value
self.clip_factor_w_max = nn.Parameter(init.clone(), requires_grad=True)
self.clip_factor_w_min = nn.Parameter(init.clone(), requires_grad=True)
self.sigmoid = nn.Sigmoid()
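> For context, learnable weight clipping usually applies these factors through the sigmoid to shrink the per-row extrema before scales are computed. A sketch of the common OmniQuant-style pattern, assumed rather than taken from this diff:

```python
import torch

def lwc_clip_range(weight, clip_factor_w_max, clip_factor_w_min):
    # Assumed LWC pattern: sigmoid-squashed factors shrink the per-row
    # min/max that the quantization scale is later derived from.
    w_max = weight.amax(dim=1, keepdim=True) * torch.sigmoid(clip_factor_w_max)
    w_min = weight.amin(dim=1, keepdim=True) * torch.sigmoid(clip_factor_w_min)
    return w_min, w_max
```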
@@ -473,7 +559,14 @@ def fake_quant(self, x):
None if self.is_sym else clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax)
)
scale, round_zero_point = self._expand_scale_zp(scale, round_zero_point, x)
return self._fake_quant_with_params(x, scale, round_zero_point)
out = self._fake_quant_with_params(x, scale, round_zero_point)
# Scale is kept in fp32 for numerical stability, but multiplying by
# a bf16/fp16 activation upcasts the result. Cast back to the input
# dtype so downstream F.linear / DeepSpeed autocast wrappers see a
# consistent dtype.
if out.dtype != x.dtype:
out = out.to(x.dtype)
return out
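> The promotion the cast-back above undoes is easy to demonstrate:

```python
import torch

x = torch.randn(4, dtype=torch.bfloat16)
scale = torch.tensor([0.01], dtype=torch.float32)
print((x * scale).dtype)   # torch.float32 -- upcast unless we cast back
```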

def forward(self, x: torch.Tensor):
if self.bits >= 16:
@@ -516,8 +609,18 @@ def __init__(
self.register_parameter("bias", org_module.bias)
self.use_weight_quant = use_weight_quant
self.use_act_quant = use_act_quant
# Under ZeRO-3 the weight Parameter ``org_module.weight`` may be a
# zero-numel shard. Pass an explicit (out, in) shape so the weight
# quantizer can size its scale Parameter from the Linear shape
# rather than inspecting the (possibly sharded) tensor.
weight_shape = (org_module.out_features, org_module.in_features)
if self.use_weight_quant:
self.weight_quantizer = Quantizer(config, quant_info, x=org_module.weight)
self.weight_quantizer = Quantizer(
config,
quant_info,
x=org_module.weight,
weight_shape=weight_shape,
)
if self.use_act_quant:
self.act_quantizer = Quantizer(config, quant_info, is_act=True, resume=resume)

@@ -531,13 +634,16 @@ def __init__(
)

def forward(self, input: torch.Tensor):
if input.shape[0] == 0:
return self.fwd_func(input, self.weight, self.bias)

weight = self.weight_quantizer(self.weight) if self.use_weight_quant else self.weight
if self.use_act_quant:
input = self.act_quantizer(input)
output = self.fwd_func(input, weight, self.bias)
# Defensive dtype alignment: upstream (DeepSpeed ZeRO-3 / HF
# autocast) may have cast ``input`` to fp16 even though we run in
# bf16. Align to the (fake-quantised) weight dtype so F.linear
# stays consistent.
output = self.fwd_func(
input.to(self.weight.dtype), weight.to(self.weight.dtype), self.bias
)
if self.use_qkv_quant:
output = self.qkv_quantizer(output)
return output
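> A minimal wrap-and-call sketch for the class above; the argument names and order are assumed from the diff context, and `config` / `quant_info` stand in for the project's real objects:

```python
import torch
import torch.nn as nn

# Hypothetical wiring; see QuantLinear.__init__ in the diff above.
lin = nn.Linear(11008, 4096, bias=True)
qlin = QuantLinear(config, quant_info, org_module=lin,
                   use_weight_quant=True, use_act_quant=True)
y = qlin(torch.randn(2, 11008, dtype=torch.bfloat16))
```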
117 changes: 117 additions & 0 deletions angelslim/compressor/qat/plugins/distill_loss.py
@@ -0,0 +1,117 @@
# Copyright 2025 Tencent Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F


class DistillLoss:
def __init__(self, loss_type="kl", loss_topk=None, kd_temperature=1.0):
self.loss_type = str(loss_type).lower()
self.loss_topk = loss_topk
self.kd_temperature = float(kd_temperature)

@staticmethod
def _kl_per_token(log_p_src: torch.Tensor, p_tgt: torch.Tensor) -> torch.Tensor:
"""Per-token KL(tgt || src) (shape [N])."""
return F.kl_div(log_p_src, p_tgt, reduction="none").sum(dim=-1)

def compute(self, student_logits, teacher_logits, labels):
"""Return per-token KD losses computed only on valid labels.

The returned dict always contains ``loss``, ``forward_kl`` and
``backward_kl`` so callers can log diagnostics for every KD variant.
"""
flat_mask = (labels != -100).reshape(-1)
if flat_mask.sum() == 0:
zero = student_logits.new_zeros(())
return {"loss": zero, "forward_kl": zero, "backward_kl": zero}

# Flatten to [N, V] and keep only valid tokens.
s_flat = student_logits.flatten(0, -2)[flat_mask]
t_flat = teacher_logits.flatten(0, -2)[flat_mask]
valid_labels = labels.reshape(-1)[flat_mask]

# Diagnostic KL (always computed at T=1 on valid tokens).
s_logp = F.log_softmax(s_flat, dim=-1)
t_logp = F.log_softmax(t_flat, dim=-1)
s_p = s_logp.exp()
t_p = t_logp.exp()
forward_kl = self._kl_per_token(s_logp, t_p).mean()
backward_kl = self._kl_per_token(t_logp, s_p).mean()

# Main KD loss according to loss_type.
if self.loss_type == "kl":
kd = forward_kl
elif self.loss_type == "rkl":
kd = backward_kl
elif self.loss_type == "mse":
kd = F.mse_loss(s_flat, t_flat)
elif self.loss_type == "kd":
# Legacy "kd": temperature-scaled forward KL. The outer trainer
# combines this value with the LM loss by configured weights.
temperature = max(self.kd_temperature, 1e-6)
kd = self._kl_per_token(
F.log_softmax(s_flat / temperature, dim=-1),
F.softmax(t_flat / temperature, dim=-1),
).mean() * (temperature * temperature)
elif self.loss_type == "cakld":
# Confidence-aware KL divergence (CAKLD): per-token mixing of
# forward / reverse KL by the teacher's confidence on the label.
per_tok_fkl = self._kl_per_token(s_logp, t_p)
per_tok_bkl = self._kl_per_token(t_logp, s_p)
conf = torch.gather(t_p, dim=-1, index=valid_labels.unsqueeze(-1)).squeeze(-1)
kd = (conf * per_tok_bkl + (1.0 - conf) * per_tok_fkl).mean()
elif self._is_reverse_topk_loss(self.loss_type):
topk = self._resolve_topk(self.loss_type, s_flat.size(-1))
top_s, idx = s_flat.topk(topk, dim=-1, sorted=False)
top_t = t_flat.gather(-1, idx)
kd = self._kl_per_token(
F.log_softmax(top_t, dim=-1),
F.softmax(top_s, dim=-1),
).mean()
elif self._is_forward_topk_loss(self.loss_type):
topk = self._resolve_topk(self.loss_type, t_flat.size(-1))
top_t, idx = t_flat.topk(topk, dim=-1, sorted=False)
top_s = s_flat.gather(-1, idx)
kd = self._kl_per_token(
F.log_softmax(top_s, dim=-1),
F.softmax(top_t, dim=-1),
).mean()
else:
raise ValueError(
f"Unsupported QAT kd loss_type: {self.loss_type}. "
"Valid: kl, rkl, mse, kd, cakld, kl_top[_K], r_kl_top[_K] (alias rkl_top[_K])."
)

return {"loss": kd, "forward_kl": forward_kl, "backward_kl": backward_kl}
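> Written out, the cakld branch mixes the two divergences per token by the teacher's confidence on the gold label:

```latex
\mathcal{L}_{\mathrm{cakld}}
  = \frac{1}{N}\sum_{i=1}^{N}
    \Big[\, c_i\,\mathrm{KL}\!\big(p_s^{(i)} \,\|\, p_t^{(i)}\big)
        + (1 - c_i)\,\mathrm{KL}\!\big(p_t^{(i)} \,\|\, p_s^{(i)}\big) \Big],
\qquad c_i = p_t^{(i)}(y_i)
```

A confident teacher weights the mode-seeking reverse KL; an uncertain one weights the mean-seeking forward KL.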

@staticmethod
def _is_forward_topk_loss(loss_type):
return loss_type.startswith("kl_top")

@staticmethod
def _is_reverse_topk_loss(loss_type):
return loss_type.startswith("r_kl_top") or loss_type.startswith("rkl_top")

def _resolve_topk(self, loss_type, vocab_size):
topk = self.loss_topk
if topk is None and "_top_" in loss_type:
topk = int(loss_type.rsplit("_", 1)[-1])
if topk is None:
topk = 1000
topk = int(topk)
if topk <= 0:
raise ValueError(f"loss_topk must be positive, got: {topk}")
return min(topk, vocab_size)
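> An end-to-end smoke test of the class (shapes illustrative; the import path is assumed from this file's location in the diff):

```python
import torch

from angelslim.compressor.qat.plugins.distill_loss import DistillLoss

student = torch.randn(2, 4, 32, requires_grad=True)   # [batch, seq, vocab]
teacher = torch.randn(2, 4, 32)
labels = torch.full((2, 4), -100)
labels[:, 1:] = torch.randint(0, 32, (2, 3))          # some valid targets

kd = DistillLoss(loss_type="kl_top_8")   # forward KL on the teacher's top-8 logits
out = kd.compute(student, teacher, labels)
out["loss"].backward()
print(out["forward_kl"].item(), out["backward_kl"].item())
```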