Fix broken HQQ support #45147
Changes from all commits: cec8546, 450363d, f8c299f, 4d1c5f0, 0632a17, 183a9ad, 555a3aa
@@ -59,10 +59,16 @@ def __init__(self, quantization_config, **kwargs):

```python
            )
        super().__init__(quantization_config, **kwargs)
        self.dtype = None
        self.device_map = None
        self.using_multi_gpu = False
        # Keys that are serialized specifically by hqq
        self.hqq_keys = HQQLinear(None, None).state_dict_keys() - {"bias"}

    def update_dtype(self, dtype):
        if dtype is not None:
            self.dtype = dtype
        return dtype
```
Comment on lines +67 to +70 (Member): do we really need that, the tensors should be in the right dtype so we should be able to access that directly
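A minimal sketch of the reviewer's suggestion, assuming the quantizer can see a loaded checkpoint tensor at that point (the helper name `infer_compute_dtype` is hypothetical, not part of this PR):

```python
import torch

def infer_compute_dtype(state_dict: dict) -> torch.dtype:
    # Hypothetical alternative to update_dtype(): read the dtype straight
    # from a floating-point tensor in the checkpoint instead of tracking
    # the user-supplied value on the quantizer.
    for tensor in state_dict.values():
        if isinstance(tensor, torch.Tensor) and tensor.is_floating_point():
            return tensor.dtype
    return torch.float32  # same default the quantizer falls back to
```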
```python
    def validate_environment(self, *args, **kwargs):
        if self.dtype is None:
            if "dtype" in kwargs:
```

@@ -72,6 +78,7 @@ def validate_environment(self, *args, **kwargs):

```python
                logger.info("Setting dtype to torch.float32 as the default value since it was not specified.")
```
Comment on lines 73 to 78 (Member): remove that
```python
        device_map = kwargs.get("device_map")
        self.device_map = device_map
        if isinstance(device_map, dict):
            if "cpu" in device_map.values() or "disk" in device_map.values():
                raise ValueError(
```
@@ -144,10 +151,16 @@ def validate_environment(self, *args, **kwargs):

```python
        # return list(new_keys)

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        module, _ = get_module_from_name(model, param_name)
        # Since we do not prepare the modules in advance, we need every param of the Linear layer to go through
        # `create_quantized_param`, even when `self.is_quantized == True`
        return isinstance(module, torch.nn.Linear)
        module, tensor_name = get_module_from_name(model, param_name)
        return isinstance(module, torch.nn.Linear) and tensor_name == "weight"

    def get_quantize_ops(self):
        from ..integrations.hqq import HqqQuantize

        return HqqQuantize(self)

    def get_weight_conversions(self):
        return []
```
Comment on lines +162 to +163 (Member): we should use deserialize
```python
    # TODO: to remove
    # def create_quantized_param(
```
@@ -232,6 +245,47 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:

```python
        # setattr(parent_module, node, hqq_layer)

    def _setup_missing_key_filters(self, model, checkpoint_files):
        """Scan checkpoint files to find HQQ-quantized modules.

        For those modules:
        1. Suppress their .weight missing key warnings in the load report.
        2. Replace their weight parameter with a scalar meta tensor so that
           ``_move_missing_keys_from_meta_to_device`` does not allocate
           full-size fp16 tensors on GPU (which would cause OOM).
        """
        import re

        from safetensors import safe_open

        quantized_modules = set()
```
Comment on lines +248 to +261 (Member): Can we not do these types of changes?
```python
        for ckpt_file in checkpoint_files:
            if ckpt_file.endswith(".safetensors"):
                with safe_open(ckpt_file, framework="pt") as f:
                    for k in f.keys():
                        if k.endswith(".W_q"):
                            quantized_modules.add(k[: -len(".W_q")])
            else:
                state_dict = torch.load(ckpt_file, map_location="cpu", weights_only=True)
                for k in state_dict:
                    if k.endswith(".W_q"):
                        quantized_modules.add(k[: -len(".W_q")])

        if quantized_modules:
            # Build regex that matches only .weight keys of quantized modules
            escaped = [re.escape(m) + r"\.weight" for m in quantized_modules]
            existing = model._keys_to_ignore_on_load_missing or []
            model._keys_to_ignore_on_load_missing = existing + escaped

            # Replace weight params with scalar meta tensors to avoid GPU allocation
            for module_name in quantized_modules:
                try:
                    module = model.get_submodule(module_name)
                except AttributeError:
                    continue
                if hasattr(module, "weight") and module.weight is not None:
                    module.weight = torch.nn.Parameter(torch.empty(0, device="meta"), requires_grad=False)
```
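To make the regex filtering above concrete, here is a simplified sketch of how entries in `_keys_to_ignore_on_load_missing` are typically consumed when the load report is built (not the actual loader code; the module name is made up):

```python
import re

patterns = [re.escape("model.layers.0.self_attn.q_proj") + r"\.weight"]
missing_keys = ["model.layers.0.self_attn.q_proj.weight", "lm_head.weight"]

# Keys matching an ignore pattern are dropped from the "missing keys" warning.
reported = [k for k in missing_keys if not any(re.search(p, k) for p in patterns)]
print(reported)  # ['lm_head.weight'] -- only genuinely missing keys remain
```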
```python
    def _patch_layer_for_multigpu(self, hqq_layer):
        def forward_with_device(self, x):
            out = torch.matmul(x.to(self.device), self.dequantize().t())
```
@@ -245,17 +299,133 @@ def forward_with_device(self, x):

```python
    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        checkpoint_files=None,
        **kwargs,
    ):
        # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
        # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
        model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)
        if self.pre_quantized:
            # Store checkpoint files for loading in _process_model_after_weight_loading
            self._checkpoint_files = checkpoint_files

            # Suppress noisy load report: HQQ checkpoint keys (W_q, scale, etc.) are
            # "unexpected" and nn.Linear .weight keys are "missing" from the standard
            # loading perspective, but _load_hqq_from_checkpoint handles them.
            hqq_keys = HQQLinear(None, None).state_dict_keys()
            ignore_unexpected = [rf"\.{k}$" for k in hqq_keys]
            existing = model._keys_to_ignore_on_load_unexpected or []
            model._keys_to_ignore_on_load_unexpected = existing + ignore_unexpected

            # For missing keys: scan checkpoint to find which modules have W_q (are HQQ-quantized),
            # and suppress only their .weight keys. Also replace their weight with a scalar meta
            # tensor to prevent _move_missing_keys_from_meta_to_device from allocating full-size
            # tensors on GPU (which would cause OOM for large models).
            self._setup_missing_key_filters(model, checkpoint_files)
        else:
            # Add the corresponding quant_config to each valid module for on-the-fly quantization.
            # prepare_for_hqq_linear() also sets the right quantization config inside the model
            # (model.config.quantization_config) and the layers (hqq_layer.quant_config)
            model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        if self.pre_quantized:
            self._load_hqq_from_checkpoint(model)
        setattr(model, "is_hqq_quantized", True)
        setattr(model, "is_hqq_serializable", self.is_serializable())
        return model

    def _load_hqq_from_checkpoint(self, model: "PreTrainedModel"):
        """Load pre-quantized HQQ weights directly from checkpoint files."""
        from collections import defaultdict

        from safetensors import safe_open

        from ..integrations.hqq import autoname_modules, name_to_linear_tag

        # Determine target device from stored device_map
        device_map = getattr(self, "device_map", None)
        if isinstance(device_map, dict):
            # Use the first non-cpu device from the map (values can be str, int, or torch.device)
            devices = [torch.device(v) for v in device_map.values()]
```
Comment on lines +335 to +347 (Member): why we need that?
```python
            cuda_devices = [d for d in devices if d.type != "cpu"]
            target_device = cuda_devices[0] if cuda_devices else torch.device("cpu")
        elif isinstance(device_map, str) and device_map not in ("cpu", "auto"):
            target_device = torch.device(device_map)
        else:
            target_device = torch.device("cpu")
```
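For reference, a toy walk-through of the device selection just above, with a hypothetical `device_map` (values may also be ints or `torch.device` objects):

```python
import torch

device_map = {"model.embed_tokens": "cuda:0", "model.layers": "cuda:0", "lm_head": "cpu"}
devices = [torch.device(v) for v in device_map.values()]   # [cuda:0, cuda:0, cpu]
cuda_devices = [d for d in devices if d.type != "cpu"]      # [cuda:0, cuda:0]
target_device = cuda_devices[0] if cuda_devices else torch.device("cpu")
print(target_device)  # cuda:0 -- the device the rebuilt HQQLinear layers are moved to
```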
```python
        autoname_modules(model)
        skip_modules = self.quantization_config.skip_modules
        hqq_state_dict_keys = HQQLinear(None, None).state_dict_keys()

        # Find which modules should be quantized
        quantizable_modules = {}
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                linear_tag = name_to_linear_tag(name)
                if linear_tag not in skip_modules:
                    quantizable_modules[name] = module

        # Load the full state dict from checkpoint files
        full_state_dict = {}
        for ckpt_file in self._checkpoint_files:
            if ckpt_file.endswith(".safetensors"):
                with safe_open(ckpt_file, framework="pt") as f:
                    for k in f.keys():
                        full_state_dict[k] = f.get_tensor(k)
            else:
                import torch as torch_

                full_state_dict.update(torch_.load(ckpt_file, map_location="cpu", weights_only=True))

        # Group state dict by module
        module_states = defaultdict(dict)
        for key, value in full_state_dict.items():
            # Find the module this key belongs to
            for module_name in quantizable_modules:
                if key.startswith(module_name + "."):
                    param_name = key[len(module_name) + 1 :]
                    if param_name in hqq_state_dict_keys:
                        module_states[module_name][param_name] = value
                    break
```
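Illustration of the grouping above (module and key names are examples; the exact per-layer keys come from `HQQLinear.state_dict_keys()`, e.g. `W_q`, `scale`, ...):

```python
# A flat checkpoint key such as
#   "model.layers.0.self_attn.q_proj.W_q"
# ends up grouped under its owning module:
#
#   module_states["model.layers.0.self_attn.q_proj"] == {
#       "W_q": ...,    # packed quantized weights
#       "scale": ...,  # per-group scales
#       ...            # remaining HQQ-serialized keys, plus "bias" if present
#   }
```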
```python
        # Replace nn.Linear with HQQLinear for each quantizable module
        for module_name, state in module_states.items():
            if "W_q" not in state:
                continue

            hqq_layer = HQQLinear(
                None,
                None,
                compute_dtype=self.dtype or torch.float16,
                device="cpu",
                initialize=False,
            )

            state["W_q"] = torch.nn.Parameter(state["W_q"], requires_grad=False)
            hqq_layer.load_state_dict(state)

            # Move to the correct device (HQQLinear.to() is a no-op, use .cuda() instead)
            if target_device.type != "cpu":
                hqq_layer.cuda(target_device)

            if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
                hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)

            if self.using_multi_gpu:
                hqq_layer = self._patch_layer_for_multigpu(hqq_layer)

            parent_name, _, child_name = module_name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, hqq_layer)

        del full_state_dict

        # Free any leftover GPU memory from replaced nn.Linear modules
        import gc

        gc.collect()
        if target_device.type != "cpu":
            torch.cuda.empty_cache()

    def is_serializable(self):
        return True
```
Comment: It would be nice if we didn't have to do that here
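For context, a rough end-to-end sketch of the flow this PR fixes, assuming an HQQ checkpoint produced with `save_pretrained` (the model id and paths below are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, HqqConfig

# 1. Quantize on the fly and serialize (is_serializable() now returns True).
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",                                   # placeholder model id
    quantization_config=HqqConfig(nbits=4, group_size=64),
    torch_dtype=torch.float16,
    device_map="cuda:0",
)
model.save_pretrained("./model-hqq-4bit")

# 2. Reload the pre-quantized checkpoint: _process_model_before_weight_loading
#    stores the checkpoint files, and _load_hqq_from_checkpoint rebuilds the
#    HQQLinear layers from their W_q / scale / ... tensors.
reloaded = AutoModelForCausalLM.from_pretrained("./model-hqq-4bit", device_map="cuda:0")
```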