176 changes: 154 additions & 22 deletions angelslim/compressor/qat/modules/quantizer.py
@@ -69,6 +69,7 @@ def __init__(self, config, quant_info, x=None, is_act=False, resume=False, num_h
self._apply_settings(info, rewrite_conf)
self._set_quant_range()
self._init_quant_params(x)
self._init_lwc_params(x, config)

def _apply_settings(self, info, rewrite_conf):
if rewrite_conf:
@@ -130,6 +131,39 @@ def _init_quant_params(self, x):
)
self._set_quant_parameters(scale, zp.round())

def _init_lwc_params(self, x, config):
lwc_cfg = config.get("lwc", {})
if isinstance(lwc_cfg, dict):
self.lwc = (not self.is_act) and bool(lwc_cfg.get("enable_lwc", False))
self.lwc_init_value = float(lwc_cfg.get("lwc_init_value", 4.0))
else:
self.lwc = (not self.is_act) and bool(lwc_cfg)
self.lwc_init_value = 4.0

if self.lwc:
if x.dim() != 2:
x_for_shape = x.flatten(1)
else:
x_for_shape = x
out_dim, in_dim = x_for_shape.shape
if self.granularity == "per-group":
if not self.group_size or self.group_size <= 0:
raise ValueError("per-group quantization requires positive group_size.")
if in_dim % self.group_size != 0:
raise ValueError(f"in_dim ({in_dim}) must be divisible by group_size ({self.group_size}).")
n_groups = in_dim // self.group_size
dim1 = out_dim * n_groups
elif self.granularity == "per-channel":
dim1 = out_dim
else:
dim1 = 1

init = (
torch.ones((dim1, 1), device=x.device, dtype=torch.float32) * self.lwc_init_value
)
self.clip_factor_w_max = nn.Parameter(init.clone(), requires_grad=True)
self.clip_factor_w_min = nn.Parameter(init.clone(), requires_grad=True)
self.sigmoid = nn.Sigmoid()

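For reference, a minimal sketch of the two `lwc` config shapes the parser above accepts (key names are taken from the code; the values are hypothetical):

```python
# Dict form: explicit enable switch plus a custom initial clip logit.
config = {"lwc": {"enable_lwc": True, "lwc_init_value": 4.0}}

# Bare-bool form: enables LWC for weights with the default init value of 4.0.
config = {"lwc": True}
```

Because the clip factors pass through a sigmoid, an init value of 4.0 starts the learned bounds at sigmoid(4.0) ≈ 0.982 of the observed min/max, i.e. nearly unclipped until training tightens them.
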
def _compute_scales(self, x, granularity="per-tensor", group_size=-1):
if granularity == "per-tensor":
s = torch.clamp(torch.max(torch.abs(x.flatten())), min=1e-8)
@@ -309,47 +343,145 @@ def _expand(t, target_shape):

elif self.granularity == "per-token":
# scale broadcasts one value per token across the feature dim (e.g. [n_tokens, 1] -> x.shape)
- init_shape = x.shape
- rx = x.reshape(-1, x.shape[-1])
- scale = _expand(scale, rx.shape).reshape(init_shape)
- zero_point = (
- _expand(zero_point, rx.shape).reshape(init_shape)
- if zero_point is not None
- else None
- )
+ scale = _expand(scale, x.shape)
+ zero_point = _expand(zero_point, x.shape) if zero_point is not None else None

elif self.granularity == "per-head":
if self.num_heads <= 0:
raise ValueError("num_heads must be set for per-head granularity.")
if x.shape[-1] % self.num_heads != 0:
raise ValueError(
f"last dim ({x.shape[-1]}) must be divisible by num_heads ({self.num_heads})"
)
head_dim = x.shape[-1] // self.num_heads
head_shape = (*x.shape[:-1], self.num_heads, head_dim)

def _expand_per_head(t):
if t is None:
return None
# Broadcast one scale per head across that head's contiguous feature slice.
view_shape = (1,) * (x.dim() - 1) + (self.num_heads, 1)
return t.reshape(view_shape).expand(head_shape).reshape(x.shape)

scale = _expand_per_head(scale)
zero_point = _expand_per_head(zero_point)

return scale, zero_point

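A quick sanity check of the per-head broadcast above — a self-contained sketch with illustrative shapes, not code from this PR:

```python
import torch

B, T, H, D = 2, 3, 4, 8               # batch, tokens, heads, head_dim
x = torch.randn(B, T, H * D)
scale = torch.rand(H)                  # one scale per head

view_shape = (1,) * (x.dim() - 1) + (H, 1)           # (1, 1, 4, 1)
expanded = scale.reshape(view_shape).expand(B, T, H, D).reshape(x.shape)

# Every feature inside head h now carries that head's scale.
assert torch.equal(expanded[0, 0, :D], scale[0].repeat(D))
```
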
- def fake_quant(self, x):
- scale = clamp_ste(self.scale, 1e-4, 1e4)
- round_zero_point = (
- None if self.is_sym else clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax)
- )
- scale, round_zero_point = self._expand_scale_zp(scale, round_zero_point, x)
def _expand_scale_zp_lwc(self, scale, zero_point, x):
def _expand(t, target_shape):
if t is None:
return None
return t.expand(target_shape)

if self.granularity == "per-channel":
scale = _expand(scale, x.shape)
zero_point = _expand(zero_point, x.shape)

elif self.granularity == "per-group":
group_size = self.group_size
scale = scale.unsqueeze(-1).expand(*scale.shape, group_size).reshape(x.shape)
if zero_point is not None:
zero_point = (
zero_point.unsqueeze(-1).expand(*zero_point.shape, group_size).reshape(x.shape)
)

return scale, zero_point

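The per-group branch repeats each group's scale across its `group_size` contiguous columns. A small standalone illustration (shapes hypothetical):

```python
import torch

out_dim, n_groups, G = 2, 3, 4
scale = torch.rand(out_dim, n_groups)          # one scale per (row, group)

full = scale.unsqueeze(-1).expand(out_dim, n_groups, G).reshape(out_dim, n_groups * G)

# Columns 0..G-1 share scale[:, 0], columns G..2G-1 share scale[:, 1], ...
assert torch.equal(full[:, :G], scale[:, :1].expand(out_dim, G))
```
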
def _lwc_fake_quant_weight(self, x: torch.Tensor) -> torch.Tensor:
# Weight-only LWC path (OmniQuant-style):
# compute scale/zp from (possibly grouped) xmin/xmax
# with learnable bound factors, then quantize with STE.
x_dtype = x.dtype
x_work = x
if x_work.dim() != 2:
x_work = x_work.flatten(1)

if self.granularity == "per-group":
out_dim, in_dim = x_work.shape
x_reduce = x_work.reshape(out_dim, in_dim // self.group_size, self.group_size)
xmin = x_reduce.amin(dim=-1)
xmax = x_reduce.amax(dim=-1)
elif self.granularity == "per-channel":
xmin = x_work.amin(dim=-1, keepdim=True)
xmax = x_work.amax(dim=-1, keepdim=True)
else:
# per-tensor (default)
xmin = x_work.amin().view(1, 1)
xmax = x_work.amax().view(1, 1)

xmax = self.sigmoid(self.clip_factor_w_max).reshape_as(xmax) * xmax
xmin = self.sigmoid(self.clip_factor_w_min).reshape_as(xmin) * xmin
if self.is_sym:
abs_max = torch.max(xmax.abs(), xmin.abs())
scale = (abs_max / self.qmax).to(dtype=x_work.dtype)
round_zero_point = None
else:
range_ = xmax - xmin
scale = (range_ / (self.qmax - self.qmin)).to(dtype=x_work.dtype)
zero_point = (-xmin) / range_ * (self.qmax - self.qmin) + self.qmin
round_zero_point = clamp_ste(round_ste(zero_point), self.qmin, self.qmax)
scale, round_zero_point = self._expand_scale_zp_lwc(scale, round_zero_point, x_work)
x_dequant = self._fake_quant_with_params(x_work, scale, round_zero_point)
return x_dequant.to(dtype=x_dtype).reshape(x.shape)

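The clipping step in `_lwc_fake_quant_weight` is OmniQuant-style learnable weight clipping: a sigmoid-squashed factor in (0, 1) scales each observed bound, so the learned range can only be pulled toward zero, never widened. A standalone sketch of just that step (the helper name is hypothetical):

```python
import torch

def lwc_bounds(x_group, gamma_min, gamma_max):
    # Shrink the raw min/max by learnable factors sigmoid(gamma) in (0, 1).
    xmin = x_group.amin() * torch.sigmoid(gamma_min)
    xmax = x_group.amax() * torch.sigmoid(gamma_max)
    return xmin, xmax
```

Training the gammas lets the quantizer trade clipping error on outlier weights against rounding error on the bulk of the distribution.
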
def _w4a8_fp8_ste_from_dequant(
self, x_dequant: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
fp8_scale = scale.max() * self.qmax / FP8_E4M3_QMAX
weight_fp8 = x_dequant / fp8_scale
weight_fp8_q = weight_fp8.clamp(FP8_E4M3_QMIN, FP8_E4M3_QMAX).to(torch.float8_e4m3fn)
weight_fp8_q = (weight_fp8_q.to(torch.bfloat16) - weight_fp8).detach() + weight_fp8
return weight_fp8_q * fp8_scale

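The `(q - x).detach() + x` pattern in `_w4a8_fp8_ste_from_dequant` is the straight-through estimator: the forward value is the quantized `q`, but autograd sees the identity in `x`, so gradients pass through the non-differentiable fp8 cast. A minimal illustration:

```python
import torch

x = torch.tensor([0.7], requires_grad=True)
q = torch.round(x)                # non-differentiable step (zero gradient)
y = (q - x).detach() + x          # forward: q; backward: dy/dx = 1
y.sum().backward()
assert x.grad.item() == 1.0
```
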
def _fake_quant_with_params(
self,
x: torch.Tensor,
scale: torch.Tensor,
round_zero_point: torch.Tensor | None,
) -> torch.Tensor:
scale = clamp_ste(scale, 1e-4, 1e4)
if round_zero_point is not None:
round_zero_point = clamp_ste(round_zero_point, self.qmin, self.qmax)

if self.is_w4a8_fp8:
x_int4 = round_ste(x / scale)
- x_int4 = clamp_ste(x_int4, self.qmin, self.qmax).mul(scale)
- fp8_scale = scale.max() * self.qmax / FP8_E4M3_QMAX
- weight_fp8 = (x_int4 / fp8_scale).clamp(-448, 448).to(torch.float8_e4m3fn)
- return weight_fp8.to(torch.bfloat16) * fp8_scale
+ x_int4 = clamp_ste(x_int4, self.qmin, self.qmax)
+ x_dequant = x_int4.mul(scale)
+ return self._w4a8_fp8_ste_from_dequant(x_dequant, scale)

if self.dtype == "fp8":
- weight_fp8 = (x / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
- return weight_fp8.to(torch.bfloat16) * scale
+ weight_fp8 = x / scale
+ weight_fp8 = clamp_ste(weight_fp8, FP8_E4M3_QMIN, FP8_E4M3_QMAX)
+ weight_fp8_q = weight_fp8.to(torch.float8_e4m3fn).to(torch.bfloat16)
+ weight_fp8_q = (weight_fp8_q - weight_fp8).detach() + weight_fp8
+ return weight_fp8_q * scale

x_int = round_ste(x / scale)
if round_zero_point is not None:
x_int = x_int.add(round_zero_point)
x_int = clamp_ste(x_int, self.qmin, self.qmax)
+ x_dequant = x_int
if round_zero_point is not None:
- x_int = x_int.sub(round_zero_point)
- return x_int.mul(scale)
+ x_dequant = x_dequant.sub(round_zero_point)
+ x_dequant = x_dequant.mul(scale)
+ return x_dequant

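The integer branch of `_fake_quant_with_params` is the standard affine fake-quant round trip. A worked example with hypothetical parameters s = 0.1, z = 3, qmin = -8, qmax = 7:

```python
s, z, qmin, qmax = 0.1, 3, -8, 7
x = 0.46
q = max(qmin, min(qmax, round(x / s) + z))   # 4.6 -> round 5 -> +z 8 -> clamp 7
x_dq = (q - z) * s                           # (7 - 3) * 0.1 = 0.4
```

The clamp at `qmax` is exactly the clipping error that the learnable LWC bounds above are trained to balance against rounding error.
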
def fake_quant(self, x):
scale = clamp_ste(self.scale, 1e-4, 1e4)
round_zero_point = (
None if self.is_sym else clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax)
)
scale, round_zero_point = self._expand_scale_zp(scale, round_zero_point, x)
return self._fake_quant_with_params(x, scale, round_zero_point)

def forward(self, x: torch.Tensor):
if self.bits >= 16:
return x

if self.lwc:
return self._lwc_fake_quant_weight(x)

if self.is_act and not self.dynamic and not self.init:
self._lazy_init(x)
return x
22 changes: 16 additions & 6 deletions angelslim/compressor/qat/plugins/learnable_scale.py
@@ -40,8 +40,9 @@ def __init__(self, quant_info=None, ignore_layers=None, resume_ckpt_dir=None, **

# Parse learnable config (boolean switches for each parameter group)
learnable_cfg = self.config.get("learnable", {})
self.learn_act_scale = learnable_cfg.get("act_scale", True)
self.learn_weight_scale = learnable_cfg.get("weight_scale", False)
self.learn_act_scale = learnable_cfg.get("act_scale", False)
self.learn_weight_scale = learnable_cfg.get("weight_scale", True)
self.learn_lwc = learnable_cfg.get("lwc", False)
self.learn_kv_scale = learnable_cfg.get("kv_scale", False)
self.learn_norm = learnable_cfg.get("norm", False)

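For reference, a sketch of the `learnable` config block these switches read, with the new defaults (key names from the code; values illustrative):

```python
config = {
    "learnable": {
        "act_scale": False,    # default flipped off in this PR
        "weight_scale": True,  # default flipped on in this PR
        "lwc": False,          # new switch: learnable weight clipping factors
        "kv_scale": False,
        "norm": False,
    }
}
```
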
@@ -100,6 +101,7 @@ def _apply_learn_strategy(self):
act_scale=self.learn_act_scale,
weight_scale=self.learn_weight_scale,
kv_scale=self.learn_kv_scale,
lwc=self.learn_lwc,
)

if self.learn_norm:
@@ -109,7 +111,8 @@
f"act_scale={self.learn_act_scale}, "
f"weight_scale={self.learn_weight_scale}, "
f"kv_scale={self.learn_kv_scale}, "
f"norm={self.learn_norm}"
f"norm={self.learn_norm}",
f"lwc={self.learn_lwc}",
)
print_info(
f"Learnable config ({learnable_summary}): "
@@ -205,15 +208,15 @@ def quant_parameters(model):
def set_weight_parameters(model, requires_grad):
params = []
for n, m in model.named_parameters():
if n.find("weight") > -1 and not (n.find("scale") > -1 or n.find("zero_point") > -1):
if n.endswith("weight") and not (n.find("scale") > -1 or n.find("zero_point") > -1):
m.requires_grad = requires_grad
return iter(params)


def weight_parameters(model):
params = []
for n, m in model.named_parameters():
if n.find("weight") > -1 and not (n.find("scale") > -1 or n.find("zero_point") > -1):
if n.endswith("weight") and not (n.find("scale") > -1 or n.find("zero_point") > -1):
params.append(m)
return iter(params)

@@ -226,7 +229,9 @@ def trainable_parameters(model):
return iter(params)


- def _set_learnable_parameters(model, act_scale=False, weight_scale=False, kv_scale=False):
+ def _set_learnable_parameters(
+ model, act_scale=False, weight_scale=False, kv_scale=False, lwc=False
+ ):
_KV_SUFFIXES = ("k_proj", "v_proj")

for name, module in model.named_modules():
Expand All @@ -243,6 +248,11 @@ def _set_learnable_parameters(model, act_scale=False, weight_scale=False, kv_sca
if "scale" in pname or "zero_point" in pname:
param.requires_grad = True

if lwc and hasattr(module, "weight_quantizer"):
for pname, param in module.weight_quantizer.named_parameters():
if "clip_factor_w_max" in pname or "clip_factor_w_min" in pname:
param.requires_grad = True

if kv_scale and hasattr(module, "qkv_quantizer"):
suffix = name.rsplit(".", 1)[-1] if "." in name else name
if suffix in _KV_SUFFIXES: