Fix broken HQQ support #45147
Conversation
@ArthurZucker @SunMarc a little bump on this, should be an easy fix.
[For maintainers] Suggested jobs to run (before merge): run-slow: hqq
SunMarc left a comment:
Thanks for looking into that, @mobicham! Since we are adding the support back, can we try to do something a bit more maintainable? This is also one of the reasons we didn't add the support back earlier: it was too complicated. I know that SINQ based a lot of their code on hqq / gemlite, and the first PR they opened did something similar to this, but in the end they managed to clean up the integration a lot and now it looks much better. Would you be up for doing that? https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_sinq.py
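For illustration, a rough sketch of the slimmer shape such an integration could take, along the lines of quantizer_sinq.py: all module surgery happens once before weight loading, and no per-key hooks are needed afterwards. Method names follow the HfQuantizer interface as currently exposed; the prepare_for_hqq_linear call and its signature are assumptions, and this is not the PR's actual code.

    # Rough sketch only; interface details and helper signatures are assumed.
    from transformers.quantizers.base import HfQuantizer
    from transformers.utils import is_hqq_available


    class SlimHqqHfQuantizer(HfQuantizer):
        requires_calibration = False

        def validate_environment(self, *args, **kwargs):
            if not is_hqq_available():
                raise ImportError("A valid HQQ version is required: `pip install hqq`.")

        def _process_model_before_weight_loading(self, model, **kwargs):
            # Replace targeted nn.Linear modules with empty HQQLinear shells up front,
            # so the generic loading path can fill them like any other module.
            from transformers.integrations import prepare_for_hqq_linear

            prepare_for_hqq_linear(model, quantization_config=self.quantization_config)

        def _process_model_after_weight_loading(self, model, **kwargs):
            return model

        @property
        def is_trainable(self):
            return True

        def is_serializable(self, safe_serialization=None):
            return True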
    # TODO: to remove
    # def create_quantized_param(
    #     self,
    #     model: "PreTrainedModel",

    def get_weight_conversions(self):
        return []
    # TODO: to remove
    # Kept here in case we see some interest in adding support for it
    # # Adds missing keys for HQQLinear modules that are loaded but the model was initialized with torch.nn.Linear
    # def update_expected_keys(
| @@ -72,6 +78,7 @@ | |||
| logger.info("Setting dtype to torch.float32 as the default value since it was not specified.") | |||
| def update_dtype(self, dtype): | ||
| if dtype is not None: | ||
| self.dtype = dtype | ||
| return dtype |
Do we really need that? The tensors should be in the right dtype, so we should be able to access it directly.
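A minimal sketch of what the reviewer is suggesting: instead of caching the dtype in update_dtype, infer the compute dtype from the tensor being processed. Names such as `value` mirror the diff below; this is illustrative, not the PR's code.

    import torch

    def infer_compute_dtype(value: torch.Tensor, default: torch.dtype = torch.float16) -> torch.dtype:
        # Floating-point checkpoint tensors already carry the dtype we want;
        # packed integer buffers (e.g. quantized weights) fall back to a default.
        return value.dtype if value.is_floating_point() else default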
    def _load_hqq_from_checkpoint(self, model: "PreTrainedModel"):
        """Load pre-quantized HQQ weights directly from checkpoint files."""
        from collections import defaultdict

        from safetensors import safe_open

        from ..integrations.hqq import autoname_modules, name_to_linear_tag

        # Determine target device from stored device_map
        device_map = getattr(self, "device_map", None)
        if isinstance(device_map, dict):
            # Use the first non-cpu device from the map (values can be str, int, or torch.device)
            devices = [torch.device(v) for v in device_map.values()]
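For reference, a self-contained sketch of how the device selection above could be completed (illustrative only, not the PR's exact code; it assumes the device_map has already been resolved to concrete devices):

    import torch

    def pick_target_device(device_map):
        """Pick the first non-CPU device from a resolved device_map, else CPU."""
        if isinstance(device_map, dict):
            # Values can be str, int, or torch.device, exactly as noted above.
            devices = [torch.device(v) for v in device_map.values()]
            non_cpu = [d for d in devices if d.type != "cpu"]
            return non_cpu[0] if non_cpu else torch.device("cpu")
        if device_map is not None:
            # A single device spec such as "cuda:0"; "auto" is assumed to have
            # been resolved to a dict before reaching this point.
            return torch.device(device_map)
        return torch.device("cpu")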
    def _setup_missing_key_filters(self, model, checkpoint_files):
        """Scan checkpoint files to find HQQ-quantized modules.

        For those modules:
        1. Suppress their .weight missing key warnings in the load report.
        2. Replace their weight parameter with a scalar meta tensor so that
           ``_move_missing_keys_from_meta_to_device`` does not allocate
           full-size fp16 tensors on GPU (which would cause OOM).
        """
        import re

        from safetensors import safe_open

        quantized_modules = set()
Can we not do these types of changes?
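A self-contained sketch of the meta-tensor trick the docstring above describes (illustrative, not the PR's code): shrinking the placeholder weight keeps the generic missing-key handling from allocating full-size fp16 weights on GPU.

    import torch

    def shrink_weight_to_meta(module: torch.nn.Module) -> None:
        """Replace a module's placeholder weight with a scalar meta tensor (no storage)."""
        if getattr(module, "weight", None) is not None:
            module.weight = torch.nn.Parameter(
                torch.empty((), device="meta"), requires_grad=False
            )

    # Hypothetical usage with the set built above:
    # for name, module in model.named_modules():
    #     if name in quantized_modules:
    #         shrink_weight_to_meta(module)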
        quant_config = getattr(module, "quant_config", None)
        if quant_config is None:
            # Module is skipped from quantization, just return the weight as-is
            return {full_layer_name: value}

        # Determine target device and compute dtype
        target_device = value.device
        compute_dtype = self.hf_quantizer.dtype

        # Create HQQLinear from the nn.Linear
        hqq_layer = HQQLinear(
            module,
            quant_config=quant_config,
            compute_dtype=compute_dtype,
            device=target_device,
            del_orig=True,
        )

        if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
            hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)

        if self.hf_quantizer.using_multi_gpu:
It would be nice if we didn't have to do that here.
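For context, here is what the conversion in the hunk above amounts to as a standalone call into the hqq library (values are illustrative; assumes `pip install hqq` and a CUDA device):

    import torch
    from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

    linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.float16)
    quant_config = BaseQuantizeConfig(nbits=4, group_size=64)

    hqq_layer = HQQLinear(
        linear,                       # float nn.Linear to quantize
        quant_config=quant_config,    # per-layer HQQ settings
        compute_dtype=torch.float16,  # dtype used for dequantized compute
        device="cuda",                # target device for the packed weights
        del_orig=True,                # free the original float weights
    )

    out = hqq_layer(torch.randn(1, 4096, dtype=torch.float16, device="cuda"))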
What does this PR do?
This PR fixes HQQ support, which has been broken for a couple of months now after a refactoring:
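Roughly what this re-enables on the user side (a minimal sketch; the model id and quantization settings are illustrative):

    import torch
    from transformers import AutoModelForCausalLM, HqqConfig

    quant_config = HqqConfig(nbits=4, group_size=64)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B",         # any causal LM checkpoint
        dtype=torch.float16,
        device_map="cuda",
        quantization_config=quant_config,  # quantize on the fly with HQQ
    )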
Who can review?
@ArthurZucker @SunMarc