132 changes: 132 additions & 0 deletions src/transformers/integrations/hqq.py
@@ -127,3 +127,135 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
logger.warning("No linear modules were found in your model for quantization.")

return model


class HqqQuantize:
"""HQQ quantization operation for the new weight loading flow."""

def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(
self,
input_dict,
full_layer_name=None,
model=None,
**kwargs,
):
from hqq.core.quantize import HQQLinear

from ..quantizers.quantizers_utils import get_module_from_name

# input_dict has {param_name: [tensor]} for the weight
value = list(input_dict.values())[0]
value = value[0] if isinstance(value, list) else value

# full_layer_name is e.g. "model.layers.0.self_attn.q_proj.weight"
module_name = full_layer_name.rsplit(".", 1)[0]
module, _ = get_module_from_name(model, full_layer_name)

# Load weight into the nn.Linear module
module.weight = torch.nn.Parameter(value, requires_grad=False)

# Get the quant_config that was set in _process_model_before_weight_loading
quant_config = getattr(module, "quant_config", None)
if quant_config is None:
# Module is skipped from quantization, just return the weight as-is
return {full_layer_name: value}

# Determine target device and compute dtype
target_device = value.device
compute_dtype = self.hf_quantizer.dtype

# Create HQQLinear from the nn.Linear
hqq_layer = HQQLinear(
module,
quant_config=quant_config,
compute_dtype=compute_dtype,
device=target_device,
del_orig=True,
)

if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)

if self.hf_quantizer.using_multi_gpu:
Comment on lines +161 to +182 (Member):
It would be nice if we didn't have to do that here

hqq_layer = self.hf_quantizer._patch_layer_for_multigpu(hqq_layer)

# Replace the module in the model
parent_module_name, _, child_name = module_name.rpartition(".")
parent_module = model.get_submodule(parent_module_name) if parent_module_name else model
setattr(parent_module, child_name, hqq_layer)

# Mark as loaded so it's not reported as missing
missing_keys = kwargs.get("missing_keys")
if missing_keys is not None:
missing_keys.discard(full_layer_name)

# Return empty dict so the loading code doesn't try to set params
return {}


class HqqDeserialize:
"""Deserialize HQQ pre-quantized weights into an HQQLinear module."""

def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(
self,
input_dict,
full_layer_name=None,
model=None,
**kwargs,
):
from hqq.core.quantize import HQQLinear

# Unwrap list values
state_dict = {}
for key, value in input_dict.items():
state_dict[key] = value[0] if isinstance(value, list) else value

# If W_q is not present, this is not an HQQ-quantized layer — pass through
if "W_q" not in state_dict:
return input_dict

# full_layer_name is e.g. "model.layers.0.self_attn.v_proj.weight"
# (target pattern "weight" appended to module path)
module_name = full_layer_name.rsplit(".", 1)[0]

parent_name, _, child_name = module_name.rpartition(".")
parent = model.get_submodule(parent_name) if parent_name else model

# Create empty HQQLinear
hqq_layer = HQQLinear(
None,
None,
compute_dtype=self.hf_quantizer.dtype or torch.float16,
device="cpu",
initialize=False,
)

# Make W_q an nn.Parameter as HQQ expects
if "W_q" in state_dict:
state_dict["W_q"] = torch.nn.Parameter(state_dict["W_q"], requires_grad=False)

hqq_layer.load_state_dict(state_dict)

if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)

if self.hf_quantizer.using_multi_gpu:
hqq_layer = self.hf_quantizer._patch_layer_for_multigpu(hqq_layer)

setattr(parent, child_name, hqq_layer)

# Mark weight and bias as loaded
missing_keys = kwargs.get("missing_keys")
if missing_keys is not None:
missing_keys.discard(full_layer_name)
# Also discard bias since HQQLinear handles it internally
bias_key = module_name + ".bias"
missing_keys.discard(bias_key)

return {}
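
For context, here is a minimal sketch (not part of this diff) of how `HqqQuantize.convert` is expected to be driven by the weight-loading flow: `input_dict` maps a checkpoint key to its tensor, and the op swaps the owning `nn.Linear` for an `HQQLinear` in place. Everything below (the toy `proj` module, the manually attached `quant_config`, the `missing_keys` set) is illustrative only; it assumes the `hqq` package is installed and a CUDA device is available.

```python
import torch
from hqq.core.quantize import BaseQuantizeConfig

from transformers import HqqConfig
from transformers.integrations.hqq import HqqQuantize
from transformers.quantizers.quantizer_hqq import HqqHfQuantizer

# Toy stand-in for a real model; in practice from_pretrained() builds the module tree.
model = torch.nn.Sequential()
model.add_module("proj", torch.nn.Linear(64, 64, bias=False, device="meta"))
# prepare_for_hqq_linear() normally attaches a per-module quant_config like this one.
model.proj.quant_config = BaseQuantizeConfig(nbits=4, group_size=64)

quantizer = HqqHfQuantizer(HqqConfig(nbits=4, group_size=64))
quantizer.dtype = torch.float16  # normally filled in by update_dtype()

op = HqqQuantize(quantizer)
missing_keys = {"proj.weight"}
leftover = op.convert(
    {"proj.weight": [torch.randn(64, 64, dtype=torch.float16, device="cuda")]},
    full_layer_name="proj.weight",
    model=model,
    missing_keys=missing_keys,
)

# The nn.Linear has been replaced in place and the key marked as handled.
print(type(model.proj).__name__, leftover, missing_keys)  # HQQLinear {} set()
```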
184 changes: 177 additions & 7 deletions src/transformers/quantizers/quantizer_hqq.py
@@ -59,10 +59,16 @@ def __init__(self, quantization_config, **kwargs):
)
super().__init__(quantization_config, **kwargs)
self.dtype = None
self.device_map = None
self.using_multi_gpu = False
# Keys that are serialized specifically by hqq
self.hqq_keys = HQQLinear(None, None).state_dict_keys() - {"bias"}

def update_dtype(self, dtype):
if dtype is not None:
self.dtype = dtype
return dtype
Comment on lines +67 to +70 (Member):
do we really need that, the tensors should be in the right dtype so we should be able to access that directly


def validate_environment(self, *args, **kwargs):
if self.dtype is None:
if "dtype" in kwargs:
@@ -72,6 +78,7 @@ def validate_environment(self, *args, **kwargs):
logger.info("Setting dtype to torch.float32 as the default value since it was not specified.")
Comment on lines 73 to 78 (Member):
remove that


device_map = kwargs.get("device_map")
self.device_map = device_map
if isinstance(device_map, dict):
if "cpu" in device_map.values() or "disk" in device_map.values():
raise ValueError(
@@ -144,10 +151,16 @@ def validate_environment(self, *args, **kwargs):
# return list(new_keys)

def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
module, _ = get_module_from_name(model, param_name)
# Since we do not prepare the modules in advance, we need every param of the Linear layer to go through
# `create_quantized_param`, even when `self.is_quantized == True`
return isinstance(module, torch.nn.Linear)
module, tensor_name = get_module_from_name(model, param_name)
return isinstance(module, torch.nn.Linear) and tensor_name == "weight"

def get_quantize_ops(self):
from ..integrations.hqq import HqqQuantize

return HqqQuantize(self)

def get_weight_conversions(self):
return []
Comment on lines +162 to +163 (Member):
we should use deserialize


# TODO: to remove
# def create_quantized_param(
Expand Down Expand Up @@ -232,6 +245,47 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **

# setattr(parent_module, node, hqq_layer)

def _setup_missing_key_filters(self, model, checkpoint_files):
"""Scan checkpoint files to find HQQ-quantized modules.

For those modules:
1. Suppress their .weight missing key warnings in the load report.
2. Replace their weight parameter with a scalar meta tensor so that
``_move_missing_keys_from_meta_to_device`` does not allocate
full-size fp16 tensors on GPU (which would cause OOM).
"""
import re

from safetensors import safe_open

quantized_modules = set()
Comment on lines +248 to +261 (Member):
Can we not do these types of changes ?

for ckpt_file in checkpoint_files:
if ckpt_file.endswith(".safetensors"):
with safe_open(ckpt_file, framework="pt") as f:
for k in f.keys():
if k.endswith(".W_q"):
quantized_modules.add(k[: -len(".W_q")])
else:
state_dict = torch.load(ckpt_file, map_location="cpu", weights_only=True)
for k in state_dict:
if k.endswith(".W_q"):
quantized_modules.add(k[: -len(".W_q")])

if quantized_modules:
# Build regex that matches only .weight keys of quantized modules
escaped = [re.escape(m) + r"\.weight" for m in quantized_modules]
existing = model._keys_to_ignore_on_load_missing or []
model._keys_to_ignore_on_load_missing = existing + escaped

# Replace weight params with scalar meta tensors to avoid GPU allocation
for module_name in quantized_modules:
try:
module = model.get_submodule(module_name)
except AttributeError:
continue
if hasattr(module, "weight") and module.weight is not None:
module.weight = torch.nn.Parameter(torch.empty(0, device="meta"), requires_grad=False)

def _patch_layer_for_multigpu(self, hqq_layer):
def forward_with_device(self, x):
out = torch.matmul(x.to(self.device), self.dequantize().t())
@@ -245,17 +299,133 @@ def forward_with_device(self, x):
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
checkpoint_files=None,
**kwargs,
):
# Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
# prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)
if self.pre_quantized:
# Store checkpoint files for loading in _process_model_after_weight_loading
self._checkpoint_files = checkpoint_files

# Suppress noisy load report: HQQ checkpoint keys (W_q, scale, etc.) are
# "unexpected" and nn.Linear .weight keys are "missing" from the standard
# loading perspective, but _load_hqq_from_checkpoint handles them.
hqq_keys = HQQLinear(None, None).state_dict_keys()
ignore_unexpected = [rf"\.{k}$" for k in hqq_keys]
existing = model._keys_to_ignore_on_load_unexpected or []
model._keys_to_ignore_on_load_unexpected = existing + ignore_unexpected

# For missing keys: scan checkpoint to find which modules have W_q (are HQQ-quantized),
# and suppress only their .weight keys. Also replace their weight with a scalar meta
# tensor to prevent _move_missing_keys_from_meta_to_device from allocating full-size
# tensors on GPU (which would cause OOM for large models).
self._setup_missing_key_filters(model, checkpoint_files)
else:
# Add the corresponding quant_config to each valid module for on-the-fly quantization.
# prepare_for_hqq_linear() also sets the right quantization config inside the model
# (model.config.quantization_config) and the layers (hqq_layer.quant_config)
model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)

def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
if self.pre_quantized:
self._load_hqq_from_checkpoint(model)
setattr(model, "is_hqq_quantized", True)
setattr(model, "is_hqq_serializable", self.is_serializable())
return model

def _load_hqq_from_checkpoint(self, model: "PreTrainedModel"):
"""Load pre-quantized HQQ weights directly from checkpoint files."""
from collections import defaultdict

from safetensors import safe_open

from ..integrations.hqq import autoname_modules, name_to_linear_tag

# Determine target device from stored device_map
device_map = getattr(self, "device_map", None)
if isinstance(device_map, dict):
# Use the first non-cpu device from the map (values can be str, int, or torch.device)
devices = [torch.device(v) for v in device_map.values()]
Comment on lines +335 to +347 (Member):
why we need that ?

cuda_devices = [d for d in devices if d.type != "cpu"]
target_device = cuda_devices[0] if cuda_devices else torch.device("cpu")
elif isinstance(device_map, str) and device_map not in ("cpu", "auto"):
target_device = torch.device(device_map)
else:
target_device = torch.device("cpu")

autoname_modules(model)
skip_modules = self.quantization_config.skip_modules
hqq_state_dict_keys = HQQLinear(None, None).state_dict_keys()

# Find which modules should be quantized
quantizable_modules = {}
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
linear_tag = name_to_linear_tag(name)
if linear_tag not in skip_modules:
quantizable_modules[name] = module

# Load the full state dict from checkpoint files
full_state_dict = {}
for ckpt_file in self._checkpoint_files:
if ckpt_file.endswith(".safetensors"):
with safe_open(ckpt_file, framework="pt") as f:
for k in f.keys():
full_state_dict[k] = f.get_tensor(k)
else:
import torch as torch_

full_state_dict.update(torch_.load(ckpt_file, map_location="cpu", weights_only=True))

# Group state dict by module
module_states = defaultdict(dict)
for key, value in full_state_dict.items():
# Find the module this key belongs to
for module_name in quantizable_modules:
if key.startswith(module_name + "."):
param_name = key[len(module_name) + 1 :]
if param_name in hqq_state_dict_keys:
module_states[module_name][param_name] = value
break

# Replace nn.Linear with HQQLinear for each quantizable module
for module_name, state in module_states.items():
if "W_q" not in state:
continue

hqq_layer = HQQLinear(
None,
None,
compute_dtype=self.dtype or torch.float16,
device="cpu",
initialize=False,
)

state["W_q"] = torch.nn.Parameter(state["W_q"], requires_grad=False)
hqq_layer.load_state_dict(state)

# Move to the correct device (HQQLinear.to() is a no-op, use .cuda() instead)
if target_device.type != "cpu":
hqq_layer.cuda(target_device)

if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)

if self.using_multi_gpu:
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)

parent_name, _, child_name = module_name.rpartition(".")
parent = model.get_submodule(parent_name) if parent_name else model
setattr(parent, child_name, hqq_layer)

del full_state_dict

# Free any leftover GPU memory from replaced nn.Linear modules
import gc

gc.collect()
if target_device.type != "cpu":
torch.cuda.empty_cache()

def is_serializable(self):
return True
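
Taken together, the two branches of `_process_model_before_weight_loading` are meant to support a quantize-then-reload round trip along these lines. This is a hedged sketch, not part of the diff: the checkpoint id and save path are placeholders, and it assumes the `hqq` package plus a CUDA device.

```python
import torch
from transformers import AutoModelForCausalLM, HqqConfig

# On-the-fly quantization at load time (the not-pre_quantized branch / HqqQuantize op).
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",                       # placeholder checkpoint id
    dtype=torch.float16,
    device_map="cuda",
    quantization_config=HqqConfig(nbits=4, group_size=64),
)

# Serializes the HQQ state (W_q etc.) along with the quantization config.
model.save_pretrained("opt-125m-hqq")

# Re-loading goes through the pre_quantized branch (_load_hqq_from_checkpoint / HqqDeserialize).
reloaded = AutoModelForCausalLM.from_pretrained("opt-125m-hqq", device_map="cuda")
```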
