diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py
index 083ec53a2fd3..f83007410f7d 100755
--- a/src/transformers/integrations/hqq.py
+++ b/src/transformers/integrations/hqq.py
@@ -127,3 +127,135 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
         logger.warning("No linear modules were found in your model for quantization.")

     return model
+
+
+class HqqQuantize:
+    """HQQ quantization operation for the new weight loading flow."""
+
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict,
+        full_layer_name=None,
+        model=None,
+        **kwargs,
+    ):
+        from hqq.core.quantize import HQQLinear
+
+        from ..quantizers.quantizers_utils import get_module_from_name
+
+        # input_dict has {param_name: [tensor]} for the weight
+        value = list(input_dict.values())[0]
+        value = value[0] if isinstance(value, list) else value
+
+        # full_layer_name is e.g. "model.layers.0.self_attn.q_proj.weight"
+        module_name = full_layer_name.rsplit(".", 1)[0]
+        module, _ = get_module_from_name(model, full_layer_name)
+
+        # Load weight into the nn.Linear module
+        module.weight = torch.nn.Parameter(value, requires_grad=False)
+
+        # Get the quant_config that was set in _process_model_before_weight_loading
+        quant_config = getattr(module, "quant_config", None)
+        if quant_config is None:
+            # Module is skipped from quantization, just return the weight as-is
+            return {full_layer_name: value}
+
+        # Determine target device and compute dtype
+        target_device = value.device
+        compute_dtype = self.hf_quantizer.dtype
+
+        # Create HQQLinear from the nn.Linear
+        hqq_layer = HQQLinear(
+            module,
+            quant_config=quant_config,
+            compute_dtype=compute_dtype,
+            device=target_device,
+            del_orig=True,
+        )
+
+        if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+            hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+        if self.hf_quantizer.using_multi_gpu:
+            hqq_layer = self.hf_quantizer._patch_layer_for_multigpu(hqq_layer)
+
+        # Replace the module in the model
+        parent_module_name, _, child_name = module_name.rpartition(".")
+        parent_module = model.get_submodule(parent_module_name) if parent_module_name else model
+        setattr(parent_module, child_name, hqq_layer)
+
+        # Mark as loaded so it's not reported as missing
+        missing_keys = kwargs.get("missing_keys")
+        if missing_keys is not None:
+            missing_keys.discard(full_layer_name)
+
+        # Return empty dict so the loading code doesn't try to set params
+        return {}
+
+
+class HqqDeserialize:
+    """Deserialize HQQ pre-quantized weights into an HQQLinear module."""
+
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict,
+        full_layer_name=None,
+        model=None,
+        **kwargs,
+    ):
+        from hqq.core.quantize import HQQLinear
+
+        # Unwrap list values
+        state_dict = {}
+        for key, value in input_dict.items():
+            state_dict[key] = value[0] if isinstance(value, list) else value
+
+        # If W_q is not present, this is not an HQQ-quantized layer — pass through
+        if "W_q" not in state_dict:
+            return input_dict
+
+        # full_layer_name is e.g. "model.layers.0.self_attn.v_proj.weight"
"model.layers.0.self_attn.v_proj.weight" + # (target pattern "weight" appended to module path) + module_name = full_layer_name.rsplit(".", 1)[0] + + parent_name, _, child_name = module_name.rpartition(".") + parent = model.get_submodule(parent_name) if parent_name else model + + # Create empty HQQLinear + hqq_layer = HQQLinear( + None, + None, + compute_dtype=self.hf_quantizer.dtype or torch.float16, + device="cpu", + initialize=False, + ) + + # Make W_q an nn.Parameter as HQQ expects + if "W_q" in state_dict: + state_dict["W_q"] = torch.nn.Parameter(state_dict["W_q"], requires_grad=False) + + hqq_layer.load_state_dict(state_dict) + + if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor): + hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + + if self.hf_quantizer.using_multi_gpu: + hqq_layer = self.hf_quantizer._patch_layer_for_multigpu(hqq_layer) + + setattr(parent, child_name, hqq_layer) + + # Mark weight and bias as loaded + missing_keys = kwargs.get("missing_keys") + if missing_keys is not None: + missing_keys.discard(full_layer_name) + # Also discard bias since HQQLinear handles it internally + bias_key = module_name + ".bias" + missing_keys.discard(bias_key) + + return {} diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 05dce3d996a0..43238e99e7e6 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -59,10 +59,16 @@ def __init__(self, quantization_config, **kwargs): ) super().__init__(quantization_config, **kwargs) self.dtype = None + self.device_map = None self.using_multi_gpu = False # Keys that are serialized specifically by hqq self.hqq_keys = HQQLinear(None, None).state_dict_keys() - {"bias"} + def update_dtype(self, dtype): + if dtype is not None: + self.dtype = dtype + return dtype + def validate_environment(self, *args, **kwargs): if self.dtype is None: if "dtype" in kwargs: @@ -72,6 +78,7 @@ def validate_environment(self, *args, **kwargs): logger.info("Setting dtype to torch.float32 as the default value since it was not specified.") device_map = kwargs.get("device_map") + self.device_map = device_map if isinstance(device_map, dict): if "cpu" in device_map.values() or "disk" in device_map.values(): raise ValueError( @@ -144,10 +151,16 @@ def validate_environment(self, *args, **kwargs): # return list(new_keys) def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: - module, _ = get_module_from_name(model, param_name) - # Since we do not prepare the modules in advance, we need every param of the Linear layer to go through - # `create_quantized_param`, even when `self.is_quantized == True` - return isinstance(module, torch.nn.Linear) + module, tensor_name = get_module_from_name(model, param_name) + return isinstance(module, torch.nn.Linear) and tensor_name == "weight" + + def get_quantize_ops(self): + from ..integrations.hqq import HqqQuantize + + return HqqQuantize(self) + + def get_weight_conversions(self): + return [] # TODO: to remove # def create_quantized_param( @@ -232,6 +245,47 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, ** # setattr(parent_module, node, hqq_layer) + def _setup_missing_key_filters(self, model, checkpoint_files): + """Scan checkpoint files to find HQQ-quantized modules. + + For those modules: + 1. Suppress their .weight missing key warnings in the load report. + 2. 
+        2. Replace their weight parameter with a zero-element meta tensor so that
+           ``_move_missing_keys_from_meta_to_device`` does not allocate
+           full-size fp16 tensors on GPU (which would cause OOM).
+        """
+        import re
+
+        from safetensors import safe_open
+
+        quantized_modules = set()
+        for ckpt_file in checkpoint_files:
+            if ckpt_file.endswith(".safetensors"):
+                with safe_open(ckpt_file, framework="pt") as f:
+                    for k in f.keys():
+                        if k.endswith(".W_q"):
+                            quantized_modules.add(k[: -len(".W_q")])
+            else:
+                state_dict = torch.load(ckpt_file, map_location="cpu", weights_only=True)
+                for k in state_dict:
+                    if k.endswith(".W_q"):
+                        quantized_modules.add(k[: -len(".W_q")])
+
+        if quantized_modules:
+            # Build a regex that matches only the .weight keys of quantized modules
+            escaped = [re.escape(m) + r"\.weight" for m in quantized_modules]
+            existing = model._keys_to_ignore_on_load_missing or []
+            model._keys_to_ignore_on_load_missing = existing + escaped
+
+            # Replace weight params with zero-element meta tensors to avoid GPU allocation
+            for module_name in quantized_modules:
+                try:
+                    module = model.get_submodule(module_name)
+                except AttributeError:
+                    continue
+                if hasattr(module, "weight") and module.weight is not None:
+                    module.weight = torch.nn.Parameter(torch.empty(0, device="meta"), requires_grad=False)
+
     def _patch_layer_for_multigpu(self, hqq_layer):
         def forward_with_device(self, x):
             out = torch.matmul(x.to(self.device), self.dequantize().t())
@@ -245,17 +299,133 @@ def forward_with_device(self, x):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
+        checkpoint_files=None,
         **kwargs,
     ):
-        # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
-        # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
-        model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)
+        if self.pre_quantized:
+            # Store checkpoint files for loading in _process_model_after_weight_loading
+            self._checkpoint_files = checkpoint_files
+
+            # Suppress noisy load report: HQQ checkpoint keys (W_q, scale, etc.) are
+            # "unexpected" and nn.Linear .weight keys are "missing" from the standard
+            # loading perspective, but _load_hqq_from_checkpoint handles them.
+            hqq_keys = HQQLinear(None, None).state_dict_keys()
+            ignore_unexpected = [rf"\.{k}$" for k in hqq_keys]
+            existing = model._keys_to_ignore_on_load_unexpected or []
+            model._keys_to_ignore_on_load_unexpected = existing + ignore_unexpected
+
+            # For missing keys: scan the checkpoint to find which modules have W_q (are HQQ-quantized)
+            # and suppress only their .weight keys. Also replace their weight with a zero-element meta
+            # tensor to prevent _move_missing_keys_from_meta_to_device from allocating full-size
+            # tensors on GPU (which would cause OOM for large models).
+            self._setup_missing_key_filters(model, checkpoint_files)
+        else:
+            # Add the corresponding quant_config to each valid module for on-the-fly quantization.
+            # prepare_for_hqq_linear() also sets the right quantization config inside the model
+            # (model.config.quantization_config) and the layers (hqq_layer.quant_config)
+            model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)

     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
+        if self.pre_quantized:
+            self._load_hqq_from_checkpoint(model)
         setattr(model, "is_hqq_quantized", True)
         setattr(model, "is_hqq_serializable", self.is_serializable())
         return model

+    def _load_hqq_from_checkpoint(self, model: "PreTrainedModel"):
+        """Load pre-quantized HQQ weights directly from checkpoint files."""
+        from collections import defaultdict
+
+        from safetensors import safe_open
+
+        from ..integrations.hqq import autoname_modules, name_to_linear_tag
+
+        # Determine target device from stored device_map
+        device_map = getattr(self, "device_map", None)
+        if isinstance(device_map, dict):
+            # Use the first non-cpu device from the map (values can be str, int, or torch.device)
+            devices = [torch.device(v) for v in device_map.values()]
+            cuda_devices = [d for d in devices if d.type != "cpu"]
+            target_device = cuda_devices[0] if cuda_devices else torch.device("cpu")
+        elif isinstance(device_map, str) and device_map not in ("cpu", "auto"):
+            target_device = torch.device(device_map)
+        else:
+            target_device = torch.device("cpu")
+
+        autoname_modules(model)
+        skip_modules = self.quantization_config.skip_modules
+        hqq_state_dict_keys = HQQLinear(None, None).state_dict_keys()
+
+        # Find which modules should be quantized
+        quantizable_modules = {}
+        for name, module in model.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                linear_tag = name_to_linear_tag(name)
+                if linear_tag not in skip_modules:
+                    quantizable_modules[name] = module
+
+        # Load the full state dict from checkpoint files
+        full_state_dict = {}
+        for ckpt_file in self._checkpoint_files:
+            if ckpt_file.endswith(".safetensors"):
+                with safe_open(ckpt_file, framework="pt") as f:
+                    for k in f.keys():
+                        full_state_dict[k] = f.get_tensor(k)
+            else:
+                import torch as torch_
+
+                full_state_dict.update(torch_.load(ckpt_file, map_location="cpu", weights_only=True))
+
+        # Group state dict by module
+        module_states = defaultdict(dict)
+        for key, value in full_state_dict.items():
+            # Find the module this key belongs to
+            for module_name in quantizable_modules:
+                if key.startswith(module_name + "."):
+                    param_name = key[len(module_name) + 1 :]
+                    if param_name in hqq_state_dict_keys:
+                        module_states[module_name][param_name] = value
+                    break
+
+        # Replace nn.Linear with HQQLinear for each quantizable module
+        for module_name, state in module_states.items():
+            if "W_q" not in state:
+                continue
+
+            hqq_layer = HQQLinear(
+                None,
+                None,
+                compute_dtype=self.dtype or torch.float16,
+                device="cpu",
+                initialize=False,
+            )
+
+            state["W_q"] = torch.nn.Parameter(state["W_q"], requires_grad=False)
+            hqq_layer.load_state_dict(state)
+
+            # Move to the correct device (HQQLinear.to() is a no-op, use .cuda() instead)
+            if target_device.type != "cpu":
+                hqq_layer.cuda(target_device)
+
+            if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+                hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+            if self.using_multi_gpu:
+                hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+            parent_name, _, child_name = module_name.rpartition(".")
+            parent = model.get_submodule(parent_name) if parent_name else model
+            setattr(parent, child_name, hqq_layer)
+
+        del full_state_dict
+
+        # Free any leftover GPU memory from replaced nn.Linear modules
+        import gc
+
+        gc.collect()
+        if target_device.type != "cpu":
+            torch.cuda.empty_cache()
+
     def is_serializable(self):
         return True
diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py
index 913bf6bf9e75..ad2797229fa5 100755
--- a/tests/quantization/hqq/test_hqq.py
+++ b/tests/quantization/hqq/test_hqq.py
@@ -14,7 +14,6 @@

 import gc
 import unittest
-from unittest import skip

 import accelerate

@@ -106,7 +105,6 @@ def test_to_dict(self):
 @require_torch_accelerator
 @require_accelerate
 @require_hqq
-@skip("skip for now until we add back support")
 class HQQTest(unittest.TestCase):
     def tearDown(self):
         cleanup()
@@ -164,7 +162,6 @@ def test_quantized_model_fake_weight_dtype(self):
 @require_torch_multi_accelerator
 @require_accelerate
 @require_hqq
-@skip("skip for now until we add back support")
 class HQQTestMultiGPU(unittest.TestCase):
     def tearDown(self):
         cleanup()
@@ -188,7 +185,6 @@ def test_fp16_quantized_model_multipgpu(self):
 @require_torch_accelerator
 @require_accelerate
 @require_hqq
-@skip("skip for now until we add back support")
 class HQQTestBias(unittest.TestCase):
     def tearDown(self):
         cleanup()
@@ -245,7 +241,6 @@ def test_save_and_load_quantized_model(self):
 @require_torch_accelerator
 @require_accelerate
 @require_hqq
-@skip("skip for now until we add back support")
 class HQQSerializationTest(unittest.TestCase):
     def tearDown(self):
         cleanup()
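
For reference, the round trip the re-enabled serialization tests exercise can be sketched as below. This is a minimal usage example, not part of the patch: the model id and save directory are placeholders, and it assumes hqq is installed, a CUDA device is available, and a transformers build that includes this change (the `dtype`/`device_map` keyword names follow the conventions used in the diff).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

model_id = "meta-llama/Llama-3.2-1B"  # placeholder model id
quant_config = HqqConfig(nbits=4, group_size=64)

# On-the-fly quantization: HqqQuantize converts each nn.Linear weight as it is loaded
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)

# Serialize the HQQ-quantized weights (W_q, scale, zero, ...) to disk
model.save_pretrained("llama-1b-hqq-4bit")

# Reload: the pre-quantized path rebuilds HQQLinear modules from the checkpoint
reloaded = AutoModelForCausalLM.from_pretrained(
    "llama-1b-hqq-4bit",
    device_map="cuda",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello", return_tensors="pt").to(reloaded.device)
print(tokenizer.decode(reloaded.generate(**inputs, max_new_tokens=20)[0]))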