diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 8cb79004d..0fe576c75 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -38,6 +38,7 @@ class BasicArgumentParser(argparse.ArgumentParser): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.add_argument( @@ -750,45 +751,10 @@ def tune(args): rotation_config=rot_config, ) - model_name = args.model.rstrip("/") - - if model_name.split("/")[-1].strip(".") == "" and "gguf" not in args.format: - if autoround.group_size <= 0: - if "fp" in autoround.act_data_type: - suffix = f"afp{autoround.act_bits}" - else: - suffix = f"a{autoround.act_bits}" - else: - suffix = f"g{autoround.group_size}" - export_dir = os.path.join(args.output_dir, f"w{autoround.bits}{suffix}") - elif model_name.split("/")[-1].strip(".") == "" and "gguf" in args.format: - export_dir = args.output_dir - elif model_name.split("./")[-1].strip("./") != "" and "gguf" in args.format: - export_dir = os.path.join(args.output_dir, model_name.split("/")[-1] + "-gguf") - else: - if isinstance(autoround.group_size, tuple): - assert len(autoround.group_size) == 2, f"Only support 2D group_size, but get {autoround.group_size}" - suffix = f"g{autoround.group_size[0]}x{autoround.group_size[1]}" - else: - if autoround.group_size <= 0: - if "fp" in autoround.act_data_type: - suffix = f"afp{autoround.act_bits}" - else: - suffix = f"a{autoround.act_bits}" - else: - suffix = f"g{autoround.group_size}" - prefix = ( - autoround.data_type.lower().replace("_", "") - if "int" not in autoround.data_type or "mx" in autoround.data_type - else "" - ) - export_dir = os.path.join( - args.output_dir, - model_name.split("/")[-1] + (f"-{prefix}" if prefix else "") + f"-w{autoround.bits}{suffix}", - ) - # ======================= Quantize and save model ======================= - model, folders = autoround.quantize_and_save(export_dir, format=args.format) # pylint: disable=E1101 + # Export directory is now derived automatically inside quantize_and_save via + # BaseCompressor._get_export_dir(), so we only need to pass the base output_dir. + model, folders = autoround.quantize_and_save(args.output_dir, format=args.format) # pylint: disable=E1101 tokenizer = autoround.tokenizer # pylint: disable=E1101 model.eval() diff --git a/auto_round/algorithms/__init__.py b/auto_round/algorithms/__init__.py new file mode 100644 index 000000000..14a492441 --- /dev/null +++ b/auto_round/algorithms/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/auto_round/algorithms/alg_config.py b/auto_round/algorithms/alg_config.py new file mode 100644 index 000000000..d9d5f0c75 --- /dev/null +++ b/auto_round/algorithms/alg_config.py @@ -0,0 +1,18 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class AlgConfig: + def __init__(self): + pass diff --git a/auto_round/algorithms/base.py b/auto_round/algorithms/base.py new file mode 100644 index 000000000..4590536c5 --- /dev/null +++ b/auto_round/algorithms/base.py @@ -0,0 +1,17 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class BaseAlgorithm: + pass diff --git a/auto_round/algorithms/quantization/__init__.py b/auto_round/algorithms/quantization/__init__.py new file mode 100644 index 000000000..6a727f31b --- /dev/null +++ b/auto_round/algorithms/quantization/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_round.algorithms.quantization.base import BaseQuantizers +from auto_round.algorithms.quantization.config import QuantizationConfig +from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig +from auto_round.algorithms.quantization.sign_round.quantizer import SignRoundQuantizer +from auto_round.algorithms.quantization.adam_round.adam import AdamRoundQuantizer +from auto_round.algorithms.quantization.rtn.config import RTNConfig +from auto_round.algorithms.quantization.rtn.quantizer import RTNQuantizer, OptimizedRTNQuantizer diff --git a/auto_round/algorithms/quantization/adam_round/__init__.py b/auto_round/algorithms/quantization/adam_round/__init__.py new file mode 100644 index 000000000..14a492441 --- /dev/null +++ b/auto_round/algorithms/quantization/adam_round/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
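The `quantization/__init__.py` above exposes the new configs and quantizers together with `BaseQuantizers.from_config`, which dispatches on a config's `_alg_cls` field (defined further down in this diff). A minimal sketch of that routing, with purely illustrative constructor arguments (in the new architecture the compressor performs this wiring; the direct calls here only illustrate the dispatch):

    from auto_round.algorithms.quantization import BaseQuantizers, RTNConfig, SignRoundConfig

    # RTNConfig flips _alg_cls to "OptimizedRTNQuantizer" unless disable_opt_rtn is set.
    rtn_cfg = RTNConfig(bits=4, group_size=128, sym=True, data_type="int")
    rtn_quantizer = BaseQuantizers.from_config(rtn_cfg)  # -> OptimizedRTNQuantizer

    # SignRoundConfig switches _alg_cls to "AdamRoundQuantizer" when enable_adam=True.
    sr_cfg = SignRoundConfig(bits=4, group_size=128, sym=True, data_type="int", enable_adam=True)
    adam_quantizer = BaseQuantizers.from_config(sr_cfg)  # -> AdamRoundQuantizer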
diff --git a/auto_round/algorithms/quantization/adam_round/adam.py b/auto_round/algorithms/quantization/adam_round/adam.py new file mode 100644 index 000000000..96835b533 --- /dev/null +++ b/auto_round/algorithms/quantization/adam_round/adam.py @@ -0,0 +1,66 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +import torch + +from auto_round.algorithms.quantization.sign_round.quantizer import SignRoundQuantizer +from auto_round.schemes import QuantizationScheme +from auto_round.utils import check_is_cpu, htcore, is_hpex_available + + +class AdamRoundQuantizer(SignRoundQuantizer): + + def __init__(self, config): + super().__init__(config) + self.momentum = None # AdamW handles momentum internally + + def _get_optimizer(self, optimizer): + if optimizer is None: + optimizer = torch.optim.AdamW + elif isinstance(optimizer, str): + optimizer = getattr(torch.optim, optimizer) + else: + optimizer = optimizer + return optimizer + + def _get_scaler(self): + scaler = None + if self.model_context.amp and not check_is_cpu(self.compress_context.device): + from torch.cuda.amp import GradScaler + + scaler = GradScaler(init_scale=1024, growth_interval=100000) + return scaler + + def _scale_loss_and_backward(self, scaler, loss): + if scaler is not None: + loss = scaler.scale(loss) + + loss.backward() + if is_hpex_available(): + htcore.mark_step() + return loss + + def _step(self, scaler, optimizer, lr_schedule): + if scaler is not None: + scaler.step(optimizer) + optimizer.zero_grad() + lr_schedule.step() + scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + lr_schedule.step() + if is_hpex_available(): + htcore.mark_step() diff --git a/auto_round/algorithms/quantization/base.py b/auto_round/algorithms/quantization/base.py new file mode 100644 index 000000000..12d972554 --- /dev/null +++ b/auto_round/algorithms/quantization/base.py @@ -0,0 +1,562 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import importlib +import traceback +from collections import defaultdict +from typing import Union + +import torch + +from auto_round.algorithms.quantization.config import QuantizationConfig +from auto_round.compressors_new.utils import ( + block_forward, + check_need_act_calibration, +) +from auto_round.data_type import QUANT_FUNC_WITH_DTYPE +from auto_round.data_type.utils import reshape_pad_tensor_by_group_size +from auto_round.logger import logger +from auto_round.utils import ( + INNER_SUPPORTED_LAYER_TYPES, + SUPPORTED_LAYER_TYPES, + check_to_quantized, + clear_memory, + compile_func, +) + + +class BaseQuantizers: + # Class-level attribute declarations for convenient access in quantization methods. + # Scheme-related attrs (layer_config, scale_dtype, has_qlayer_outside_block, etc.) + # are resolved by SchemeMixin in BaseCompressor and synced here after post_init(). + model_context = None + compress_context = None + dataset = None + supported_types = SUPPORTED_LAYER_TYPES + inner_supported_types = INNER_SUPPORTED_LAYER_TYPES + enable_alg_ext = False + # Subclasses that support diffusion models should override this with the + # appropriate output key mapping, e.g.: + # DIFFUSION_OUTPUT_CONFIGS = {"FluxTransformerBlock": ["encoder_hidden_states", "hidden_states"]} + DIFFUSION_OUTPUT_CONFIGS: dict = { + "FluxTransformerBlock": ["encoder_hidden_states", "hidden_states"], + "FluxSingleTransformerBlock": ["encoder_hidden_states", "hidden_states"], + "OvisImageTransformerBlock": ["encoder_hidden_states", "hidden_states"], + "OvisImageSingleTransformerBlock": ["encoder_hidden_states", "hidden_states"], + } + + def __init__(self, config: QuantizationConfig): + self.config = config + self.layer_config = None + self.bits = config.bits + self.group_size = config.group_size + self.sym = config.sym + self.data_type = config.data_type + self.act_bits = config.act_bits + self.act_group_size = config.act_group_size + self.act_sym = config.act_sym + self.act_data_type = config.act_data_type + self.act_dynamic = config.act_dynamic + self.super_bits = config.super_bits + self.super_group_size = config.super_group_size + self.scale_dtype = config.scale_dtype + self.ignore_layers = config.ignore_layers + self.quant_lm_head = config.quant_lm_head + self.to_quant_block_names = config.to_quant_block_names + # Calibration / sampling attrs – synced from compressor in post_init. + self.seqlen = 2048 + self.nsamples = 128 + self.batch_size = getattr(config, "batch_size", 8) + self.batch_dim = getattr(config, "batch_dim", None) + self.infer_bs_coeff = getattr(config, "infer_bs_coeff", 1) + # Whether to feed quantized-block outputs as inputs to the next block. + # Subclasses that support cascaded quantized-input (e.g. SignRoundQuantizer) + # override this from their config. Defaults to False for zero-shot algorithms + # (RTN) where activations are not used during weight optimization. 
+ self.enable_quanted_input = getattr(config, "enable_quanted_input", False) + + @classmethod + def from_config(cls, config: QuantizationConfig): + if cls.__name__ == config._alg_cls: + return cls(config) + else: + module = importlib.import_module("auto_round.algorithms.quantization") + alg_cls = getattr(module, config._alg_cls) + return alg_cls(config) + + @property + def formats(self): + return getattr(self.compress_context, "formats", None) + + @property + def amp(self): + return getattr(self.model_context, "amp", False) + + @property + def amp_dtype(self): + import torch + + return getattr(self.model_context, "amp_dtype", torch.float32) + + def _register_act_max_hook(self, model): + + def get_act_max_hook(module, input, output): + if isinstance(input, (tuple, list)): + input = input[0] + if input.numel() == 0: + return # as no needs for act_max update + input, _, _ = reshape_pad_tensor_by_group_size(input, self.act_group_size) + act_max = torch.max(torch.abs(input), dim=-1).values + if not hasattr(module, "act_max") or module.act_max.numel() == 0: + module.act_max = act_max + if self.config.is_act_nv_fp: ## for nvfp per-tensor input_global_scale calculation usage + max_val = act_max.max() + module.act_max = max_val.unsqueeze(0) if max_val.dim() == 0 else max_val + else: + act_max = act_max.to(module.act_max.device) + if self.config.is_act_nv_fp: ## for nvfp per-tensor input_global_scale calculation usage + max_val = torch.max(act_max.max(), module.act_max.max()) + module.act_max = max_val.unsqueeze(0) if max_val.dim() == 0 else max_val + else: + module.act_max = torch.max(act_max, module.act_max) + + hook_handles = [] + # for single layers out of blocks, like lm_head + if isinstance(model, SUPPORTED_LAYER_TYPES): + m = model + if ( + hasattr(m, "act_dynamic") + and check_need_act_calibration(m.act_dynamic, m.act_data_type, m.act_bits) + and check_to_quantized(m) + ): + hook = m.register_forward_hook(get_act_max_hook) + hook_handles.append(hook) + return hook_handles + + for n, m in model.named_modules(): + if ( + hasattr(m, "act_dynamic") + and check_need_act_calibration(m.act_dynamic, m.act_data_type, m.act_bits) + and check_to_quantized(m) + ): + hook = m.register_forward_hook(get_act_max_hook) + hook_handles.append(hook) + continue + + # for whole model, RTN + if n in self.layer_config: + config = self.layer_config[n] + act_dynamic = config.get("act_dynamic", True) + act_data_type = config.get("act_data_type", None) + act_bits = config.get("act_bits", 16) + if ( + config["bits"] <= 8 + and check_need_act_calibration(act_dynamic, act_data_type, act_bits) + and check_to_quantized(config) + ): + hook = m.register_forward_hook(get_act_max_hook) + hook_handles.append(hook) + continue + return hook_handles + + @torch.inference_mode() + def _quantize_embedding_layer(self): + """Quantizes embedding layers in the model according to the configuration. + + This method iterates through all modules in the model, identifies embedding + layers specified in `self.quantizer.layer_config`, and applies the appropriate quantization + function based on bit precision, grouping strategy, and dtype. + + Returns: + bool: True if the quantization process completes without critical errors. 
+ """ + is_quantized = False + for name, module in self.model_context.model.named_modules(): + # Skip non-Embedding modules or layers not in config + if not isinstance(module, torch.nn.Embedding) or name not in self.layer_config: + continue + + config = self.layer_config[name] + + # Skip layers that are not marked for quantization + if not check_to_quantized(config): + continue + is_quantized = True + config["scale_dtype"] = self.scale_dtype + dtype = config["data_type"] + + # Determine quantization function key with symmetry/asymmetry + if dtype not in QUANT_FUNC_WITH_DTYPE: + dtype = f"{dtype}_{'sym' if config['sym'] else 'asym'}" + + quant_func = QUANT_FUNC_WITH_DTYPE[dtype] + dtype = module.weight.dtype + # As typically float32 are used in RTN to search scale zp, + # to avoid cache a bf16 copy we'd better use float32 + if config.get("super_group_size", None) is not None: + dtype = torch.float32 + + # Attempt quantization on GPU, fall back to CPU if OOM + try: + weight, scale, zp = quant_func( + module.weight.to(dtype=dtype, device=self.compress_context.device), + **{ + k: config.get(k, None) + for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"] + }, + ) + except torch.OutOfMemoryError: + cuda_error_msg = traceback.format_exc() + try: + logger.error(cuda_error_msg) + logger.warning("falling back to CPU") + weight, scale, zp = quant_func( + module.weight.to("cpu"), + **{ + k: config.get(k, None) + for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"] + }, + ) + except Exception as e: + raise + + # Overwrite the module's weights with the quantized version + module.weight.data.copy_(weight.cpu()) + + # Attach scale and zero point (zp) to the module + for param_name, value in zip(["scale", "zp"], [scale, zp]): + if isinstance(value, dict): + for k, v in value.items(): + setattr(module, k if k == "scale" else f"w_{k}", v.cpu()) + elif isinstance(value, torch.Tensor): + setattr(module, param_name, value.cpu()) + else: + setattr(module, param_name, value) + + # Update config + self.layer_config.setdefault(name, {}).update(config) + del weight + del scale + del zp + clear_memory(device_list=self.compress_context.device_list) + + return is_quantized + + def quantize_block( + self, block: torch.nn.Module, input_ids=None, input_others=None, reference_output=None, **kwargs + ) -> dict: + """Apply the quantization algorithm to a prepared block. + + This is the **pure-algorithm** entry point called by the Compressor after + all infrastructure work (device placement, data collection, act-max hook + registration, DDP setup) has been completed. + + Implementations should: + - Perform the algorithm-specific weight/activation quantization on ``block``. + - Return a dict of best parameters (may be empty for zero-shot algorithms). + + Args: + block: Module already placed on the correct device(s). + input_ids: Calibration inputs on cache_device (None for zero-shot RTN). + input_others: Additional inputs (None for zero-shot RTN). + reference_output: FP reference outputs collected by Compressor + (None for algorithms that don't need a reconstruction loss). + **kwargs: Algorithm-specific keyword arguments (e.g. ``loss_device``, + ``card_0_in_high_risk`` for SignRoundQuantizer). + + Returns: + dict: Best quantization parameters found, or ``{}`` if not applicable. 
+ """ + raise NotImplementedError("quantize_block must be implemented in subclasses of BaseQuantizers") + + def quantize_layer(self, layer_name: str, **kwargs): + """Quantizes a single layer of the model. + + Args: + layer_name (str): The name of the layer to quantize. The layer module is + retrieved internally via get_module(model, layer_name). + """ + raise NotImplementedError("quantize_layer must be implemented in subclasses of BaseQuantizers") + + def quantize_layer_outside_block(self, layer_name: str, **kwargs): + """Quantizes a single layer of the model outside of a block. + + Args: + layer_name (str): The name of the layer to quantize. The layer module is + retrieved internally via get_module(model, layer_name). + """ + raise NotImplementedError("quantize_layer_outside_block must be implemented in subclasses of BaseQuantizers") + + @torch.no_grad() + def _get_block_outputs( + self, + block: torch.nn.Module, + input_ids, + input_others, + bs: int, + save_output: bool = True, + ): + """Compute the output of a block for calibration inputs. + + Shared by SignRoundQuantizer and OptimizedRTNQuantizer. Algorithm-specific + block-forward selection (compile vs. plain) is handled here based on + ``enable_alg_ext`` and act-quantization flags. + """ + diffusion_fn = getattr(self, "_get_diffusion_block_outputs", None) + if getattr(self.model_context, "is_diffusion", False): + return self._get_diffusion_block_outputs( + block, + input_ids, + input_others, + bs, + self.compress_context.device, + self.compress_context.cache_device, + ) + + _bf = self._resolve_block_forward() + + output = [] + nsamples = len(input_ids) + for i in range(0, nsamples, bs): + end_index = min(nsamples, i + bs) + indices = torch.arange(i, end_index).to(torch.long) + tmp_input_ids, tmp_input_others = self._sampling_inputs( + input_ids, + input_others, + indices, + self.seqlen, + self.batch_dim, + share_cache_keys=self.model_context.shared_cache_keys, + ) + tmp_output = _bf( + block, + tmp_input_ids, + tmp_input_others, + self.model_context.amp, + self.model_context.amp_dtype, + self.compress_context.device, + ).to(self.compress_context.cache_device) + if save_output: + if self.batch_size == 1: + output.append(tmp_output) + else: + output.extend(list(torch.split(tmp_output, 1, dim=self.batch_dim))) + self.compress_context.clear_memory() + + return output + + def _resolve_block_forward(self): + """Resolve and cache the block forward function once. + + This avoids repeated attribute checks in the hot training loop + (called thousands of times per block). + + For activation-quantization schemes (e.g. FP8_STATIC) or when + algorithm extensions are enabled, forward hooks are attached to layers + inside the block. ``torch.compile`` is incompatible with these hooks, + so we must fall back to the plain ``block_forward``. This mirrors the + old-arch behaviour where ``self.block_forward`` was set in ``__init__`` + to the uncompiled function for these cases. + """ + cached = self.__dict__.get("_resolved_block_forward") + if cached is not None: + return cached + # Act-quantization hooks / alg-extension hooks are incompatible with + # torch.compile → always use the plain (uncompiled) block_forward. 
+ if ( + self.config.is_act_quantize and (not self.config.act_dynamic or self.config.is_act_nv_fp) + ) or self.enable_alg_ext: + self._resolved_block_forward = block_forward + elif self.compress_context.enable_torch_compile: + compiled = self.__dict__.get("_compiled_block_forward") + if compiled is None: + compiled = compile_func(block_forward, self.compress_context.device) + self._compiled_block_forward = compiled + self._resolved_block_forward = compiled + else: + self._resolved_block_forward = block_forward + return self._resolved_block_forward + + def _invalidate_block_forward_cache(self): + """Clear the cached block forward function (call when block changes).""" + self.__dict__.pop("_resolved_block_forward", None) + self.__dict__.pop("_compiled_block_forward", None) + + def _get_current_q_output( + self, + block: torch.nn.Module, + input_ids, + input_others: dict, + indices, + device, + cache_device: str = "cpu", + ) -> torch.Tensor: + """Compute block output for a mini-batch selected by *indices* (used during training). + + Handles both LLM and diffusion model block formats. Uses the compiled + block_forward when enable_torch_compile is True (same as _get_block_outputs), + matching old-arch behaviour where self.block_forward was compiled at init. + """ + current_input_ids, current_input_others = self._sampling_inputs( + input_ids, + input_others, + indices, + seqlen=self.seqlen, + batch_dim=self.batch_dim, + share_cache_keys=self.model_context.shared_cache_keys, + ) + _bf = self._resolve_block_forward() + + if getattr(self.model_context, "is_diffusion", False): + output_config = self.DIFFUSION_OUTPUT_CONFIGS.get(block.__class__.__name__, ["hidden_states"]) + idx = None if "hidden_states" not in output_config else output_config.index("hidden_states") + if isinstance(current_input_ids, dict): + hidden_states = current_input_ids.pop("hidden_states") + current_input_others.update(current_input_ids) + current_input_ids = hidden_states + output_q = _bf( + block, + current_input_ids, + current_input_others, + self.model_context.amp, + self.model_context.amp_dtype, + device, + idx, + ) + else: + output_q = _bf( + block, + current_input_ids, + current_input_others, + self.model_context.amp, + self.model_context.amp_dtype, + device, + ) + return output_q.to(cache_device) + + @classmethod + @torch.no_grad() + def _sampling_inputs( + cls, + input_ids: Union[list[torch.Tensor], dict], + input_others: dict, + indices, + seqlen: int, + batch_dim: int = 0, + share_cache_keys: tuple = (), + ): + """Sample a mini-batch of calibration inputs by indices. + + Shared by SignRoundQuantizer and OptimizedRTNQuantizer. + """ + if isinstance(input_ids, list): + current_input_ids = [input_ids[i] for i in indices] + current_input_ids = torch.cat(current_input_ids, dim=batch_dim) + elif isinstance(input_ids, dict): + current_input_ids = defaultdict(list) + for k in input_ids.keys(): + current_input_ids[k].extend([input_ids[k][i] for i in indices]) + current_input_ids[k] = torch.cat(current_input_ids[k], dim=batch_dim) + + current_input_others = {"positional_inputs": input_others["positional_inputs"]} + for key in input_others.keys(): + if "positional_inputs" in key: + continue + if key in share_cache_keys: + # Shared keys are stored once (not per-sample), often wrapped in a + # 1-element list by the caching hook. Unwrap so the model receives + # the raw value (e.g. (cos, sin) tuple, not [(cos, sin)]). 
+ val = input_others[key] + if isinstance(val, list) and len(val) == 1: + current_input_others[key] = val[0] + else: + current_input_others[key] = val + elif not isinstance(input_others[key], (str, bool, type(None))): + current_input_others[key] = None + if input_others[key] is not None: + current_input_others[key] = [input_others[key][i] for i in indices] + if len(indices) == 1: + current_input_others[key] = current_input_others[key][0] + else: + try: + current_input_others[key] = torch.cat(current_input_others[key], dim=0) + except TypeError as err: + logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.") + else: + current_input_others[key] = input_others[key] + + return current_input_ids, current_input_others + + @torch.no_grad() + def _get_diffusion_block_outputs( + self, + block: torch.nn.Module, + input_ids: Union[torch.Tensor, dict], + input_others, + bs: int, + device: Union[str, torch.device], + cache_device: Union[str, torch.device], + save_output: bool = True, + ): + """Compute block outputs for diffusion models. + + Uses ``self.DIFFUSION_OUTPUT_CONFIGS`` to map block class names to their + output keys. Subclasses override ``DIFFUSION_OUTPUT_CONFIGS`` to add + support for new diffusion architectures. + """ + output = defaultdict(list) + output_config = self.DIFFUSION_OUTPUT_CONFIGS.get(block.__class__.__name__, ["hidden_states"]) + if isinstance(input_ids, dict): + nsamples = len(input_ids["hidden_states"]) + else: + nsamples = len(input_ids) + + for i in range(0, nsamples, bs): + end_index = min(nsamples, i + bs) + indices = torch.arange(i, end_index).to(torch.long) + tmp_input_ids, tmp_input_others = self._sampling_inputs( + input_ids, + input_others, + indices, + self.seqlen, + self.batch_dim, + share_cache_keys=self.model_context.shared_cache_keys, + ) + if isinstance(tmp_input_ids, dict): + hidden_states = tmp_input_ids.pop("hidden_states") + tmp_input_others.update(tmp_input_ids) + tmp_input_ids = hidden_states + + tmp_output = block_forward( + block, + tmp_input_ids, + tmp_input_others, + self.model_context.amp, + self.model_context.amp_dtype, + device, + None, + ) + if isinstance(tmp_output, torch.Tensor): + tmp_output = [tmp_output] + assert len(output_config) == len(tmp_output) + tmp_output = dict(zip(output_config, tmp_output)) + + if save_output: + for name, out in tmp_output.items(): + if self.batch_size == 1: + output[name].append(out.to(cache_device)) + else: + output[name].extend(list(torch.split(out.to(cache_device), 1, dim=self.batch_dim))) + self.compress_context.clear_memory() + + return output diff --git a/auto_round/algorithms/quantization/config.py b/auto_round/algorithms/quantization/config.py new file mode 100644 index 000000000..d99219e51 --- /dev/null +++ b/auto_round/algorithms/quantization/config.py @@ -0,0 +1,227 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +from enum import Enum +from typing import ClassVar, Union + +from auto_round.algorithms.alg_config import AlgConfig +from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG +from auto_round.logger import logger +from auto_round.schemes import QuantizationScheme + + +class BackendDataType(str, Enum): + STANDARD_FP = "fp" + MX_FP = "mx_fp" + NV_FP = "nv_fp" + FP8_STATIC = "fp8_static" + FP8 = "fp8" + + +@dataclass(kw_only=True) +class QuantizationConfig: + _alg_cls: ClassVar[str] = None + + # quantization args + bits: int = None + group_size: int = None + sym: bool = None + data_type: str = None + act_bits: int = None + act_group_size: int = None + act_sym: bool = None + act_data_type: str = None + act_dynamic: bool = None + super_bits: int = None + super_group_size: int = None + scale_dtype: str = None + ignore_layers: str = "" + quant_lm_head: bool = False + to_quant_block_names: Union[str, list, None] = None + + def __post_init__(self): + # Run block-wise validation early (at construction time, before model loading). + # Scheme resolution is deferred to BaseCompressor.post_init() via SchemeMixin. + # Guard with None checks in case the user hasn't explicitly set data_type/bits + # (they will be resolved from scheme by the compressor before use). + if self.group_size is not None and isinstance(self.group_size, (tuple, list)): + if not ( + self.data_type is not None + and self.bits is not None + and self.data_type.startswith("fp") + and self.bits == 8 + ): + raise ValueError( + "Block-wise quantization (tuple group_size) only supports fp8 weight quantization, " + f"but got data_type='{self.data_type}', bits={self.bits}." + ) + if ( + self.act_dynamic is not None + and self.act_data_type is not None + and self.act_bits is not None + and not (self.act_dynamic and self.act_data_type.startswith("fp") and self.act_bits == 8) + ): + raise NotImplementedError( + "Block-wise fp8 weight quantization only supports dynamic fp8 activation quantization. " + f"Got act_dynamic={self.act_dynamic}, act_data_type='{self.act_data_type}', " + f"act_bits={self.act_bits}." + ) + if self.act_group_size is not None and isinstance(self.act_group_size, (tuple, list)): + raise ValueError( + "`act_group_size` must be -1 (per channel), 0 (per-tensor), or a positive integer, not a tuple." + ) + + @staticmethod + def _is_valid_group_size(gs) -> bool: + """Return True if gs is a valid group_size value. + + Accepts -1 (per-channel), 0 (per-tensor), a positive integer, + or a tuple/list of such values (e.g. (128, 128) for block-wise FP8). + """ + if isinstance(gs, (tuple, list)): + return all(QuantizationConfig._is_valid_group_size(g) for g in gs) + return gs == -1 or gs >= 0 + + def check_config(self) -> None: + """Checks if the configurations are valid. + + Raises: + ValueError, TypeError: If any of the configurations are invalid. + """ + if self.bits <= 0: + raise ValueError("`bits` must be positive") + if self.act_bits <= 0: + raise ValueError("`act_bits` must be positive") + if not self._is_valid_group_size(self.group_size): + raise ValueError( + "`group_size` must be -1 (per channel), 0 (per-tensor), a positive integer, " + "or a tuple thereof (e.g. (128, 128) for block-wise quantization)" + ) + if isinstance(self.act_group_size, (tuple, list)): + raise ValueError( + "`act_group_size` must be -1 (per channel), 0 (per-tensor), or a positive integer, not a tuple." 
+ ) + if not self._is_valid_group_size(self.act_group_size): + raise ValueError( + "`act_group_size` must be -1 (per channel), 0 (per-tensor), a positive integer, " "or a tuple thereof" + ) + # Block-wise (tuple group_size) is only valid for fp8 weight quantization + if isinstance(self.group_size, (tuple, list)): + if not (self.data_type.startswith("fp") and self.bits == 8): + raise ValueError( + "Block-wise quantization (tuple group_size) only supports fp8 weight quantization, " + f"but got data_type='{self.data_type}', bits={self.bits}." + ) + if not (self.act_dynamic and self.act_data_type.startswith("fp") and self.act_bits == 8): + raise NotImplementedError( + "Block-wise fp8 weight quantization only supports dynamic fp8 activation quantization. " + f"Got act_dynamic={self.act_dynamic}, act_data_type='{self.act_data_type}', " + f"act_bits={self.act_bits}." + ) + # Reset the default value of super_bits and super_group_size + if self.data_type.endswith("_dq"): + gguf_config = GGUF_INNER_CONFIG[f"gguf:q{self.bits}_k"] + self.super_bits = gguf_config.get("super_bits", None) if self.super_bits is None else self.super_bits + self.super_group_size = ( + gguf_config.get("super_group_size", None) if self.super_group_size is None else self.super_group_size + ) + + if ( + self.is_act_quantize + and (not self.is_act_nv_fp or "static_gs" not in self.act_data_type) + and not self.is_act_mx_fp + and not self.is_dynamic_wint8aint8 + and not self.is_static_afp8 + ): + logger.warning( + "activation quantization is an experimental feature with limited support and a complex API. " + "And please save the quantized model to fake format as real deployment is not supported currently" + ) + # For block-wise group_size (tuple), skip the scalar-only warnings + scalar_gs = self.group_size if not isinstance(self.group_size, (tuple, list)) else None + if self.is_mx_fp and scalar_gs != 32: + logger.warning("dtype mx_fp should only support group_size of 32 in real deployment") + if self.is_nv_fp and scalar_gs != 16: + logger.warning("dtype nv_fp should only support group_size of 16 in real deployment") + + @property + def is_act_quantize(self): + return self.act_bits is not None and self.act_bits <= 8 + + @property + def is_nv_fp(self): + return self.data_type is not None and BackendDataType.NV_FP in self.data_type + + @property + def is_act_nv_fp(self): + return self.act_data_type is not None and BackendDataType.NV_FP in self.act_data_type + + @property + def is_mx_fp(self): + return self.data_type is not None and BackendDataType.MX_FP in self.data_type + + @property + def is_act_mx_fp(self): + return self.act_data_type is not None and BackendDataType.MX_FP in self.act_data_type + + @property + def is_dynamic_wint8aint8(self): + if self.act_dynamic: + return True + if self.act_data_type is not None and self.data_type is not None: + if ("int8" in self.act_data_type or ("int" in self.act_data_type and self.act_bits == 8)) and ( + "int8" in self.data_type or ("int" in self.data_type and self.bits == 8) + ): + return True + return False + + @property + def is_standard_fp(self): + return ( + self.data_type is not None + and BackendDataType.STANDARD_FP in self.data_type + and not self.is_mx_fp + and not self.is_nv_fp + ) + + @property + def is_act_standard_fp(self): + return ( + self.act_data_type is not None + and BackendDataType.STANDARD_FP in self.act_data_type + and not self.is_act_mx_fp + and not self.is_act_nv_fp + ) + + @property + def is_static_afp8(self): + return self.act_data_type is not None and 
BackendDataType.FP8_STATIC in self.act_data_type + + @property + def is_static_wfp8afp8(self): + return self.data_type is not None and BackendDataType.FP8_STATIC in self.data_type and self.is_static_afp8 + + @property + def is_wfp8afp8(self): + if self.act_data_type is None or self.data_type is None: + return False + if ( + ("fp8" in self.act_data_type or ("fp" in self.act_data_type and self.act_bits == 8)) + and ("fp8" in self.data_type or ("fp" in self.data_type and self.bits == 8)) + and self.is_act_standard_fp + and self.is_standard_fp + ): + return True + else: + return False diff --git a/auto_round/algorithms/quantization/rtn/__init__.py b/auto_round/algorithms/quantization/rtn/__init__.py new file mode 100644 index 000000000..14a492441 --- /dev/null +++ b/auto_round/algorithms/quantization/rtn/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/auto_round/algorithms/quantization/rtn/config.py b/auto_round/algorithms/quantization/rtn/config.py new file mode 100644 index 000000000..6afc41b0f --- /dev/null +++ b/auto_round/algorithms/quantization/rtn/config.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_round.algorithms.quantization.config import QuantizationConfig +from auto_round.data_type import QUANT_FUNC_WITH_DTYPE +from auto_round.logger import logger + + +class RTNConfig(QuantizationConfig): + _alg_cls = "RTNQuantizer" + + def __init__( + self, + *, + disable_opt_rtn: bool = None, + # for opt-rtn + batch_size: int = 8, + **kwargs, + ): + # pop before super().__init__ so it doesn't leak into QuantizationConfig as an unknown kwarg + enable_opt_rtn = kwargs.pop("enable_opt_rtn", None) + super().__init__(**kwargs) + + self.batch_size = batch_size + + # Some helpers + self.infer_bs_coeff = 1 + self.batch_dim = None + + if enable_opt_rtn: + disable_opt_rtn = False + self.orig_disable_opt_rtn = disable_opt_rtn + + if disable_opt_rtn is None: + if self.bits and self.bits >= 8 and self.act_bits and self.act_bits >= 8 and self.data_type == "int": + logger.warning("`disable_opt_rtn` is turned on for W8A16/W8A8 quantization to improve efficiency.") + disable_opt_rtn = True + if disable_opt_rtn is None: + logger.info( + "`enable_opt_rtn` is turned on, set `--disable_opt_rtn` for higher speed at the cost of accuracy." 
+ ) + disable_opt_rtn = False + self.disable_opt_rtn = disable_opt_rtn + if not self.disable_opt_rtn: + self._alg_cls = "OptimizedRTNQuantizer" diff --git a/auto_round/algorithms/quantization/rtn/quantizer.py b/auto_round/algorithms/quantization/rtn/quantizer.py new file mode 100644 index 000000000..864ca40fe --- /dev/null +++ b/auto_round/algorithms/quantization/rtn/quantizer.py @@ -0,0 +1,253 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import traceback +from collections import defaultdict +from typing import Any, Callable, Optional, Union + +import accelerate +import torch + +from auto_round.algorithms.quantization.base import BaseQuantizers +from auto_round.algorithms.quantization.rtn.config import RTNConfig +from auto_round.algorithms.quantization.sign_round.quantizer import SignRoundQuantizer +from auto_round.compressors_new.shard_writer import ShardWriter +from auto_round.compressors_new.utils import ( + IndexSampler, + block_forward, + check_need_act_calibration, + check_skippable_keywords, + collect_best_params, + get_shared_keys, + immediate_pack, + infer_bits_by_data_type, + init_cache, + reset_params, + set_layer_config, +) +from auto_round.data_type.utils import update_block_global_scale_if_needed +from auto_round.logger import logger +from auto_round.utils import ( + check_to_quantized, + convert_module_to_hp_if_necessary, + get_lm_head_name, + get_module, + htcore, + is_auto_device_mapping, + is_hpex_available, + memory_monitor, + set_amax_for_all_moe_layers, + set_module, +) +from auto_round.utils.device import ( + clear_memory_if_reached_threshold, + get_major_device, + parse_available_devices, + set_auto_device_map_for_block_with_tuning, + set_non_auto_device_map, +) +from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block + + +class RTNQuantizer(BaseQuantizers): + + def __init__(self, config: RTNConfig): + BaseQuantizers.__init__(self, config) + + def quantize_block( + self, block: torch.nn.Module, input_ids=None, input_others=None, reference_output=None, **kwargs + ) -> dict: + """Apply zero-shot RTN quantization to a block. + + Pure-algorithm entry point. Infrastructure (materialize, shard writing, + device cleanup) is handled by the Compressor before/after this call. + + Args: + block: Module already materialized and placed on the correct device. + input_ids: Unused for zero-shot RTN (accepted for interface consistency). + input_others: Unused for zero-shot RTN. + reference_output: Unused for zero-shot RTN. + + Returns: + dict: Empty dict (zero-shot RTN has no tunable parameters to return). + """ + if ( + self.config.is_act_nv_fp + or self.config.is_static_afp8 + or (self.config.is_wfp8afp8 and not self.config.act_dynamic) + ): + # For FP8 static / NVFP paths, expert input scales are derived during + # layer quantization from the current act_max. 
Unify MoE input-proj
+            # act_max values before quantizing each expert so exported input_scale
+            # stays aligned across experts.
+            set_amax_for_all_moe_layers(block, attr_name="act_max")
+
+        for _name, m in block.named_modules():
+            if hasattr(m, "global_name") and check_to_quantized(m):
+                self.quantize_layer(m.global_name)
+        return {}
+
+    def quantize_layer(self, name: str, dtype: torch.dtype = None) -> None:
+        """Quantizes a layer using RTN (Round-To-Nearest) if available.
+
+        This function attempts to quantize a layer by switching its data type to a
+        `rtn_*` version if supported, then wraps and unwraps the module to apply
+        quantization. If GPU memory is insufficient, it falls back to CPU.
+
+        If packing is enabled (`immediate_packing`), the function will also export
+        the quantized layer to the appropriate backend format.
+
+        Args:
+            name (str): Name of the layer to quantize.
+            dtype (torch.dtype, optional): If provided, the layer is cast to this
+                dtype before quantization.
+
+        Raises:
+            RuntimeError: If quantization fails for reasons unrelated to memory.
+        """
+
+        m = get_module(self.model, name)
+        if dtype is not None:
+            m = m.to(dtype)
+
+        m = convert_module_to_hp_if_necessary(m, self.model_context.amp_dtype, self.compress_context.device)
+        set_module(self.model, name, m)
+        tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.compress_context.device
+        # Step 1: for gguf, let layer merging / module renaming happen first; the gguf-specific RTN logic is handled separately
+        if (
+            self.compress_context.is_immediate_packing
+            and self.compress_context.formats[0].is_gguf()
+            and not getattr(self.config, "disable_opt_rtn", False)
+        ):
+            m = m.to(tuning_device)
+            m.scale = None
+            m.zp = None
+        else:
+            try:
+                disable_opt_rtn = bool(getattr(self.config, "disable_opt_rtn", False))
+                if (
+                    not disable_opt_rtn
+                    and self.config.orig_disable_opt_rtn is None
+                    and self.model_context.is_moe_model
+                    and "expert" in m.global_name
+                    and "shared_expert" not in m.global_name
+                    and self.config.super_bits is None  # GGUF still uses the optimized RTN for MoE layers
+                ):
+                    disable_opt_rtn = True
+                    logger.warning_once(
+                        "MoE layer detected: optimized RTN is disabled for efficiency. "
+                        "Use `--enable_opt_rtn` to force-enable it for MoE layers."
+ ) + m = m.to(tuning_device) + m = WrapperLinear( + m, + device=tuning_device, + enable_minmax_tuning=False, + enable_norm_bias_tuning=False, + enable_round_tuning=False, + enable_torch_compile=self.compress_context.enable_torch_compile, + disable_opt_rtn=disable_opt_rtn, + enable_rtn=True, + ) + m = m.unwrapper({}) + except torch.OutOfMemoryError: + cuda_error_msg = traceback.format_exc() + m = m.orig_layer if hasattr(m, "orig_layer") else m + try: + logger.error(cuda_error_msg) + logger.warning("falling back to CPU.") + m.to("cpu") + m = WrapperLinear( + m, + enable_minmax_tuning=False, + enable_norm_bias_tuning=False, + enable_round_tuning=False, + enable_torch_compile=self.compress_context.enable_torch_compile, + enable_rtn=True, + ) + m = m.unwrapper({}) + except Exception as e: + raise + + set_module(self.model, name, m) + self._immediate_pack_and_save_module(name) + + def _immediate_pack_and_save_module(self, module_name): + shard_writer = ShardWriter.get_shard_writer() + to_cpu = self.compress_context.low_gpu_mem_usage + module = get_module(self.model, module_name) + if self.compress_context.is_immediate_packing: # For gguf, packing conducts on block level + immediate_pack(module_name, self.layer_config) + if to_cpu: + module = module.to("cpu") + packed_module = get_module(self.model, module_name) + set_module(self.model, module_name, packed_module.to("cpu")) + else: + if to_cpu: + module = module.to("cpu") + set_module(self.model, module_name, module) + if self.compress_context.is_immediate_saving: + module = get_module(self.model, module_name) + module.to("cpu") + shard_writer.write(module, module_name, False) + # Free RAM immediately: the data is now in the shard-writer buffer + # (and will be flushed to disk). Keeping it also in the model tree + # causes linear RAM growth for large models. + module.to("meta") + + +class OptimizedRTNQuantizer(RTNQuantizer): + + def __init__(self, config: RTNConfig): + BaseQuantizers.__init__(self, config) + self.batch_size = config.batch_size + self.batch_dim = config.batch_dim + self.data_type = config.data_type + self.group_size = config.group_size + self.infer_bs_coeff = config.infer_bs_coeff + + self.enable_alg_ext = True + + def quantize_layer_outside_block(self, *args, **kwargs): + return self.quantize_layer(*args, **kwargs) + + def quantize_block( + self, block: torch.nn.Module, input_ids=None, input_others=None, reference_output=None, **kwargs + ): + """Apply imatrix-informed RTN quantization to a block. + + Pure-algorithm entry point. All infrastructure (device placement, + act-max hook registration, imatrix collection, cleanup) is handled + by the Compressor before calling this method. + + Args: + block: Module already placed on the correct device(s) with act_max + attributes populated by the Compressor's hook pass. + input_ids: Unused for optimized RTN; accepted for interface consistency. + input_others: Unused for optimized RTN. + reference_output: Unused for optimized RTN. 
+ """ + update_block_global_scale_if_needed(block, self.data_type, self.group_size) + if ( + self.config.is_act_nv_fp + or self.config.is_static_afp8 + or (self.config.is_wfp8afp8 and not self.config.act_dynamic) + ): + # enable moe experts act_max automatic generation for Linear + set_amax_for_all_moe_layers(block, attr_name="act_max") + # Normalize imatrix and quantize layers + for name, m in block.named_modules(): + if hasattr(m, "imatrix"): + m.imatrix /= m.imatrix_cnt + if hasattr(m, "global_name") and check_to_quantized(m): + self.quantize_layer_outside_block(m.global_name) + + # _get_block_outputs and _sampling_inputs are defined in BaseQuantizers and inherited. diff --git a/auto_round/algorithms/quantization/sign_round/__init__.py b/auto_round/algorithms/quantization/sign_round/__init__.py new file mode 100644 index 000000000..14a492441 --- /dev/null +++ b/auto_round/algorithms/quantization/sign_round/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/auto_round/algorithms/quantization/sign_round/config.py b/auto_round/algorithms/quantization/sign_round/config.py new file mode 100644 index 000000000..90b7d86b1 --- /dev/null +++ b/auto_round/algorithms/quantization/sign_round/config.py @@ -0,0 +1,111 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +from auto_round.algorithms.quantization.config import QuantizationConfig +from auto_round.logger import logger + + +class SignRoundConfig(QuantizationConfig): + """ + + Args: + iters (int): Number of iterations (default is 200). + lr (float): The learning rate (default is 0.005). + minmax_lr (float): The learning rate for min-max tuning (default is None). + lr_scheduler: The learning rate scheduler to be used. + batch_size (int): Batch size for training (default is 8). + enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True). 
+ enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning + """ + + _alg_cls = "SignRoundQuantizer" + + def __init__( + self, + *, + iters: int = 200, + lr: float = None, + minmax_lr: float = None, + lr_scheduler=None, + momentum: float = 0.0, + batch_size: int = 8, + nblocks: int = 1, + enable_minmax_tuning: bool = True, + enable_norm_bias_tuning: bool = False, + gradient_accumulate_steps: int = 1, + enable_alg_ext: bool = False, + not_use_best_mse: bool = False, + dynamic_max_gap: int = -1, + enable_quanted_input: bool = True, + optimizer: str = None, + enable_adam: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.iters = iters + if self.iters < 0: + logger.warning("`iters` must be non-negative, reset it to 200") + self.iters = 200 + + if not lr: + # TODO need to check 4 bits lr setting for auto-round-best, 3bits only validate on small models + if self.iters >= 1000 and self.bits is not None and self.bits <= 3: + self.lr = 2.0 / self.iters + logger.info("set the lr to 2.0/iters for better accuracy") + else: + self.lr = 1.0 / self.iters + else: + self.lr = lr + self.minmax_lr = minmax_lr or self.lr + self.lr_scheduler = lr_scheduler + + self.batch_size, self.gradient_accumulate_steps = batch_size, gradient_accumulate_steps + self.nblocks = nblocks + self.momentum = momentum + self.enable_alg_ext = enable_alg_ext + + # Some helpers + self.infer_bs_coeff = 1 + self.batch_dim = None + + self.enable_minmax_tuning = enable_minmax_tuning + self.enable_norm_bias_tuning = enable_norm_bias_tuning + if self.enable_norm_bias_tuning: + logger.warning("the `enable_norm_bias_tuning` feature is experimental and currently has limited support.") + self.not_use_best_mse = not_use_best_mse + self.dynamic_max_gap = dynamic_max_gap + self.enable_quanted_input = enable_quanted_input + self.optimizer = optimizer + self.enable_adam = enable_adam + + if self.enable_adam: + self._alg_cls = "AdamRoundQuantizer" + + def check_configs(self) -> None: + """Checks if the configurations are valid. + + Raises: + ValueError, TypeError: If any of the configurations are invalid. + """ + super().check_config() + + if self.batch_size <= 0: + raise ValueError("`batch_size` must be positive") + if self.iters < 0: + raise ValueError("`iters` must be non-negative") + if self.nblocks <= 0: + raise ValueError("`nblocks` must be positive") + if self.gradient_accumulate_steps <= 0: + raise ValueError("`gradient_accumulate_steps` must be positive") diff --git a/auto_round/algorithms/quantization/sign_round/quantizer.py b/auto_round/algorithms/quantization/sign_round/quantizer.py new file mode 100644 index 000000000..1f7b5b539 --- /dev/null +++ b/auto_round/algorithms/quantization/sign_round/quantizer.py @@ -0,0 +1,568 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +from collections import defaultdict +from contextlib import nullcontext +from functools import partial +from typing import Any, Callable, Optional, Union + +import accelerate +import torch +from torch import autocast + +from auto_round.algorithms.quantization.base import BaseQuantizers +from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig +from auto_round.algorithms.quantization.sign_round.sign_sgd import SignSGD +from auto_round.compressors_new.utils import ( + IndexSampler, + block_forward, + check_need_act_calibration, + collect_best_params, + immediate_pack, +) +from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, update_fused_layer_global_scales +from auto_round.logger import logger +from auto_round.utils import ( + check_to_quantized, + compile_func, + convert_module_to_hp_if_necessary, + get_module, + htcore, + is_auto_device_mapping, + is_hpex_available, + memory_monitor, + mv_module_from_gpu, + set_amax_for_all_moe_layers, + to_device, +) +from auto_round.utils.device import ( + clear_memory_if_reached_threshold, + set_auto_device_map_for_block_with_tuning, +) +from auto_round.utils.distributed import setup_ddp_if_needed_ +from auto_round.wrapper import WrapperLinear, unwrapper_block, unwrapper_layer, wrapper_block + + +class SignRoundQuantizer(BaseQuantizers): + + def __init__(self, config: SignRoundConfig): + super().__init__(config) + self.attention_mask = [] + + self.iters = config.iters + self.lr = config.lr + self.minmax_lr = config.minmax_lr + self.lr_scheduler = config.lr_scheduler + self.batch_size = config.batch_size + self.batch_dim = config.batch_dim + self.momentum = config.momentum + self.infer_bs_coeff = config.infer_bs_coeff + self.enable_minmax_tuning = config.enable_minmax_tuning + self.enable_norm_bias_tuning = config.enable_norm_bias_tuning + self.gradient_accumulate_steps = config.gradient_accumulate_steps + self.enable_alg_ext = config.enable_alg_ext + self.not_use_best_mse = config.not_use_best_mse + self.enable_quanted_input = config.enable_quanted_input + self.dynamic_max_gap = config.dynamic_max_gap + + self.optimizer = self._get_optimizer(optimizer=config.optimizer) + self.wrapper_block = wrapper_block + + if self.enable_alg_ext: + try: + logger.info("using algorithm extension for quantization.") + from auto_round.alg_ext import wrapper_autoround + + wrapper_autoround(self) + except (ImportError, ModuleNotFoundError): + logger.error("algorithm extension import error, fallback to default mode") + + def _get_current_output(self, output: list[torch.Tensor], indices: list[int]) -> torch.Tensor: + if self.model_context.is_diffusion: + assert "hidden_states" in output + current_output = [output["hidden_states"][x] for x in indices] + current_output = torch.cat(current_output, dim=self.batch_dim) + return current_output + current_output = [output[x] for x in indices] + current_output = torch.cat(current_output, dim=self.batch_dim) + return current_output + + def _get_current_num_elm( + self, + input_ids: list[torch.Tensor], + indices: list[int], + ) -> int: + if self.model_context.is_diffusion: + current_input_ids = [input_ids["hidden_states"][i] for i in indices] + return sum(id.numel() for id in current_input_ids) + + current_input_ids = [input_ids[i] for i in indices] + return sum(id.numel() for id in current_input_ids) + + def _get_non_zero_cnt(self, tensor: list[torch.Tensor], indices: list[int]) -> int: + current_tensors = [tensor[i] for i in indices] + non_zero_cnt = 0 + for t in 
current_tensors:
+            non_zero_cnt += torch.count_nonzero(t).item()
+        return non_zero_cnt
+
+    def _get_loss(
+        self,
+        output_q: torch.Tensor,
+        current_output: torch.Tensor,
+        indices: torch.Tensor,
+        mse_loss: Callable,
+        device: Union[str, torch.device] = "cpu",
+    ):
+        # Run the loss under autocast only when AMP is enabled (same convention as
+        # quantize_layer_outside_block below).
+        autocast_ctx = (
+            nullcontext()
+            if not self.model_context.amp
+            else autocast(device_type=str(device).split(":")[0], dtype=self.model_context.amp_dtype)
+        )
+        if self.attention_mask:
+            tmp_attention_mask = [self.attention_mask[i] for i in indices]
+            tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+            tmp_attention_mask.unsqueeze_(-1)
+
+            with autocast_ctx:
+                loss = mse_loss(  # pylint: disable=not-callable
+                    (output_q * tmp_attention_mask).to(torch.float32),
+                    (current_output * tmp_attention_mask).to(torch.float32),
+                )
+        else:
+            with autocast_ctx:
+                loss = mse_loss(  # pylint: disable=not-callable
+                    output_q.to(torch.float32), current_output.to(torch.float32)
+                )
+
+        return loss
+
+    def quantize_block(
+        self,
+        block: torch.nn.Module,
+        input_ids: Union[list[torch.Tensor], dict],
+        input_others: dict,
+        reference_output,
+        *,
+        loss_device: Union[str, torch.device],
+        mid_iter_mem_check: bool = False,
+        **kwargs,
+    ) -> dict:
+        """Apply the AutoRound optimization algorithm to a block.
+
+        This is the pure-algorithm entry point. All infrastructure concerns
+        (device placement, act-max hook collection, reference-output caching,
+        DDP setup, memory cleanup, logging) are handled by the Compressor
+        before and after this call.
+
+        Args:
+            block: Module already placed on the correct device(s).
+            input_ids: Calibration inputs (already on cache_device).
+            input_others: Additional inputs for the block's forward pass.
+            reference_output: FP reference outputs collected by the Compressor.
+            loss_device: Device on which to compute the MSE loss.
+            mid_iter_mem_check: Pre-evaluated by the Compressor as
+                ``low_gpu_mem_usage and card_0_in_high_risk``. When True,
+                triggers mid-iteration memory threshold checks to reduce
+                fragmentation on the primary GPU.
+
+        Returns:
+            best_params: Best quantization parameters found during optimization.
+                Empty dict if no trainable parameters were found.
+ """ + device = self.compress_context.device + + quantized_layer_names, unquantized_layer_names = self.wrapper_block( + block, + self.enable_minmax_tuning, + self.enable_norm_bias_tuning, + enable_torch_compile=self.compress_context.enable_torch_compile, + device=device, + ) + if self.config.is_nv_fp: + for module in block.modules(): + update_fused_layer_global_scales(module) + round_params = [] + minmax_params = [] + for n, m in block.named_modules(): + if hasattr(m, "orig_layer"): + for key in m.params.keys(): + if "min" in key or "max" in key: + minmax_params.append(m.params[key]) + else: + round_params.append(m.params[key]) + + lr = torch.tensor(self.lr) + minmax_lr = torch.tensor(self.minmax_lr) + + extra_kwargs = {} if self.momentum is None else {"momentum": self.momentum} + + if self.enable_minmax_tuning: + params = [ + {"params": round_params}, + {"params": minmax_params, "lr": minmax_lr}, + ] + else: + params = round_params + + optimizer = self.optimizer( + params, + lr=lr, + weight_decay=0, + **extra_kwargs, + ) + + if len(round_params) + len(minmax_params) <= 0: + dump_info = ( + f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " + f"layers in the block" + ) + logger.info(dump_info) + unwrapper_block(block, {}) + return {} + + if self.lr_scheduler is None: + lr_schedule = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.iters + ) + else: + lr_schedule = copy.deepcopy(self.lr_scheduler) + + if isinstance(input_ids, dict): # input_ids of Flux is dict + nsamples = len(input_ids["hidden_states"]) + else: + nsamples = len(input_ids) + last_best_iter = 0 + best_loss = torch.finfo(torch.float).max + num_elm = 1 + mse_reduction = "mean" + if self.gradient_accumulate_steps != 1: + mse_reduction = "sum" + mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device) + scaler = self._get_scaler() # pylint: disable=assignment-from-none + init_loss = None + best_params = {} + total_loss = 0 + global_batch_size = self.batch_size * self.gradient_accumulate_steps + global_batch_size = min(nsamples, global_batch_size) + # We assume the block input and output shape is same + if self.gradient_accumulate_steps != 1 and not self.attention_mask: + whole_indices = torch.arange(global_batch_size) + num_elm = self._get_current_num_elm(input_ids, whole_indices) + setup_ddp_if_needed_(self, block, self.compress_context.device_list) + index_sampler = IndexSampler(nsamples, global_batch_size) + batch_size = self.batch_size + for i in range(self.iters): + if self.enable_alg_ext and self.data_type.endswith("dq"): + for n, m in block.named_modules(): + m.cur_iter = i + total_loss = 0 + global_indices = index_sampler.next_batch() + if self.attention_mask: + num_elm = self._get_non_zero_cnt(self.attention_mask, global_indices) + + for batch_start in range(0, len(global_indices), batch_size): + indices = global_indices[batch_start : batch_start + batch_size] + current_output = self._get_current_output(reference_output, indices) + current_output = to_device(current_output, loss_device) + output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device) + loss = self._get_loss(output_q, current_output, indices, mse_loss, device) + num_elm = 1 if num_elm <= 0 else num_elm + total_loss += loss.item() / num_elm + + if mid_iter_mem_check: + # clear memory to avoid OOM due to memory fragmentation + clear_memory_if_reached_threshold(threshold=0.5, 
device_list=self.compress_context.device_list) + + self._scale_loss_and_backward(scaler, loss) + + if mid_iter_mem_check: + # clear memory to avoid OOM due to memory fragmentation + clear_memory_if_reached_threshold(threshold=0.8, device_list=self.compress_context.device_list) + + if i == 0: + init_loss = total_loss + + if total_loss < best_loss: + best_loss = total_loss + if not self.not_use_best_mse: + best_params = collect_best_params(block, self.compress_context.cache_device) + last_best_iter = i + if self.not_use_best_mse and i == self.iters - 1: + best_params = collect_best_params(block, self.compress_context.cache_device) + + if not self.not_use_best_mse: + if 0 < self.dynamic_max_gap <= i - last_best_iter: + break + self._step(scaler, optimizer, lr_schedule) + + last_loss = total_loss + best_iter = self.iters + if not self.not_use_best_mse: + last_loss = best_loss + best_iter = last_best_iter + if self.iters > 0: + dump_info = ( + f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " + f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" + ) + else: + dump_info = ( + f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " + "layers in the block" + ) + + self.compress_context.clear_memory() # clear cached memory during training + if len(unquantized_layer_names) != 0: + logger.info(f"Unquantized layers: {unquantized_layer_names}") + with torch.no_grad(): + unwrapper_block(block, best_params) + + if self.config.is_act_nv_fp: + # enable moe experts act_max automatic generation for WrapperWALayer + set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max") + + logger.infoclean(dump_info) + return best_params + + def quantize_layer_outside_block( + self, layer_name: str, input_ids: torch.Tensor, q_inputs: torch.Tensor = None, device: str = "cpu", **kwargs + ): + """Quantize a specific layer of the model using the provided inputs. + + Args: + layer_name (str): The name of the layer to quantize. + inputs (torch.Tensor): Input data for quantization. + q_inputs (torch.Tensor, optional): Quantized input data. Defaults to None. + device (torch.device, optional): The device to use for quantization. Defaults to torch.device("cpu"). 
+
+        Returns:
+            None
+        """
+        logger.info(f"quantizing layer {layer_name}")
+        layer = get_module(self.model, layer_name)
+        if hasattr(layer, "tuning_device"):
+            device = layer.tuning_device
+
+        layer = layer.to(device)
+        for i in range(len(input_ids)):
+            input_ids[i] = input_ids[i].to(layer.weight.dtype)
+            if q_inputs is not None:
+                q_inputs[i] = q_inputs[i].to(layer.weight.dtype)
+
+        static_kv_dtype = self.compress_context.static_kv_dtype
+        static_attention_dtype = self.compress_context.static_attention_dtype
+        if self.config.is_act_quantize and check_need_act_calibration(
+            self.config.act_dynamic,
+            self.config.act_data_type,
+            self.config.act_bits,
+            static_kv_dtype,
+            static_attention_dtype,
+        ):
+            tmp_inputs = q_inputs if q_inputs is not None else input_ids
+            hook_handles = self._register_act_max_hook(layer)
+            with torch.no_grad():
+                for input in tmp_inputs:
+                    layer(input)
+            for handle in hook_handles:
+                handle.remove()
+
+        wrapper_linear = WrapperLinear(
+            layer,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            enable_torch_compile=self.compress_context.enable_torch_compile,
+            device=device,
+        ).to(device)
+        round_params = []
+        minmax_params = []
+        for key in wrapper_linear.params.keys():
+            if "min" in key or "max" in key:
+                minmax_params.append(wrapper_linear.params[key])
+            else:
+                round_params.append(wrapper_linear.value)
+        if len(round_params) + len(minmax_params) <= 0:
+            dump_info = f"quantized {layer_name}"
+            logger.info(dump_info)
+            with torch.no_grad():
+                unwrapper_layer(self.model, wrapper_linear, layer_name, {})
+            mv_module_from_gpu(layer)
+            return  # no trainable parameters; skip optimizer setup
+
+        lr = torch.tensor(self.lr)
+        minmax_lr = torch.tensor(self.minmax_lr)
+        if self.enable_minmax_tuning:
+            optimizer = self.optimizer(
+                [{"params": round_params}, {"params": minmax_params, "lr": minmax_lr}], lr=lr, weight_decay=0
+            )
+        else:
+            optimizer = self.optimizer(round_params, lr=lr, weight_decay=0)
+
+        if self.lr_scheduler is None:
+            lr_schedule = torch.optim.lr_scheduler.LinearLR(
+                optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.iters
+            )
+        else:
+            lr_schedule = copy.deepcopy(self.lr_scheduler)
+        nsamples = len(input_ids)
+        last_best_iter = 0
+        best_loss = torch.finfo(torch.float).max
+        best_params = None
+        scaler = self._get_scaler()  # pylint: disable=assignment-from-none
+        init_loss = None
+        gradient_accumulate_steps = self.batch_size  # Force to low gpu
+
+        total_loss = 0
+        num_elm = 1
+        mse_reduction = "mean"
+        if gradient_accumulate_steps != 1:
+            mse_reduction = "sum"
+        mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device)
+        batch_size = 1  # Force to low gpu
+        global_batch_size = self.batch_size * gradient_accumulate_steps
+        global_batch_size = min(nsamples, global_batch_size)
+        if gradient_accumulate_steps != 1 and not self.attention_mask:
+            whole_indices = torch.arange(global_batch_size)
+            if q_inputs is not None:
+                num_elm = self._get_current_num_elm(q_inputs, whole_indices)
+            else:
+                num_elm = self._get_current_num_elm(input_ids, whole_indices)
+
+        index_sampler = IndexSampler(nsamples, global_batch_size)
+
+        for i in range(self.iters):
+            total_loss = 0
+            global_indices = index_sampler.next_batch()
+            if self.attention_mask:
+                num_elm = self._get_non_zero_cnt(self.attention_mask, global_indices)
+
+            for batch_start in range(0, len(global_indices), batch_size):
+                indices = global_indices[batch_start : batch_start + batch_size]
+                if q_inputs is not None:
+                    current_input = [q_inputs[i] for i in indices]
+                    current_input = torch.cat(current_input, dim=0).to(device)
+                    org_input = [input_ids[i] for i
in indices] + org_input = torch.cat(org_input, dim=0).to(device) + else: + current_input = [input_ids[i] for i in indices] + current_input = torch.cat(current_input, dim=0).to(device) + org_input = current_input + with torch.no_grad(): + current_output = layer(org_input) + autocast_ctx = ( + nullcontext() + if not self.model_context.amp + else autocast(device_type=str(device).split(":")[0], dtype=self.model_context.amp_dtype) + ) + if self.attention_mask: + tmp_attention_mask = [self.attention_mask[i] for i in indices] + tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device) + tmp_attention_mask.unsqueeze_(-1) + + with autocast_ctx: + output_q = wrapper_linear(current_input) # pylint: disable=not-callable + loss = mse_loss( # pylint: disable=not-callable + (output_q * tmp_attention_mask).to(torch.float32), + (current_output * tmp_attention_mask).to(torch.float32), + ) + + else: + with autocast_ctx: + output_q = wrapper_linear(current_input) # pylint: disable=not-callable + loss = mse_loss( # pylint: disable=not-callable + output_q.to(torch.float32), + current_output.to(torch.float32), # mul 1.0 will copy the output + ) + + num_elm = 1 if num_elm <= 0 else num_elm + total_loss += loss.item() / num_elm + + self._scale_loss_and_backward(scaler, loss) + if i == 0: + init_loss = total_loss + + if total_loss < best_loss: + best_loss = total_loss + if not self.not_use_best_mse: + best_params = collect_best_params(wrapper_linear, self.compress_context.cache_device) + last_best_iter = i + if self.not_use_best_mse and i == self.iters - 1: + best_params = collect_best_params(wrapper_linear, self.compress_context.cache_device) + + if not self.not_use_best_mse: + if 0 < self.dynamic_max_gap <= i - last_best_iter: + break + self._step(scaler, optimizer, lr_schedule) + + last_loss = total_loss + best_iter = self.iters + if not self.not_use_best_mse: + last_loss = best_loss + best_iter = last_best_iter + with torch.no_grad(): + unwrapper_layer(self.model, wrapper_linear, layer_name, best_params) + mv_module_from_gpu(layer) + dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" + logger.info(dump_info) + + def _get_optimizer(self, optimizer: Any): + """Returns the specified optimizer. In SignRound, we fix the optimizer. + + Args: + optimizer: The optimizer to be used. + + Returns: + The specified optimizer. + """ + if optimizer is not None: + logger.warning_once( + "The optimizer setting in config will be ignored in AutoRound, using SignSGD as default." + ) + return SignSGD + + def _get_scaler(self): + """Returns scaler, in SignRound, no need to use scaler.""" + return None + + def _scale_loss_and_backward(self, scaler: Any, loss: torch.Tensor) -> torch.Tensor: + """Scales the loss and performs backward pass. + + Args: + scaler: The scaler to be used. + loss: The loss to be scaled. + + Returns: + The scaled loss. + """ + scale_loss = loss * 1000 + scale_loss.backward() + if is_hpex_available(): + htcore.mark_step() + return scale_loss + + def _step(self, scaler: Any, optimizer: Any, lr_schedule: Any): + """Performs a step in the optimization process. + + Args: + scaler: The scaler to be used. + optimizer: The optimizer for the step. + lr_schedule: The learning rate schedule. 
+ + Returns: + None + """ + optimizer.step() + # for hpu + if is_hpex_available(): + htcore.mark_step() + optimizer.zero_grad() + lr_schedule.step() diff --git a/auto_round/sign_sgd.py b/auto_round/algorithms/quantization/sign_round/sign_sgd.py similarity index 100% rename from auto_round/sign_sgd.py rename to auto_round/algorithms/quantization/sign_round/sign_sgd.py diff --git a/auto_round/algorithms/transforms/__init__.py b/auto_round/algorithms/transforms/__init__.py new file mode 100644 index 000000000..6648cf6d6 --- /dev/null +++ b/auto_round/algorithms/transforms/__init__.py @@ -0,0 +1,146 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Weight/activation rotation algorithm package. + +This package houses all *pre-quantisation rotation/transform* algorithms – +mathematical operations applied to model weights or activations before the +quantisation step to improve numerical properties. + +Current algorithms +------------------ +* **hadamard** – Block-diagonal Hadamard rotations (QuaRot / SpinQuant style). + See :mod:`auto_round.algorithms.transforms.rotation`. + +Adding a new algorithm +----------------------- +1. Create ``algorithms/transforms//`` with ``config.py`` and ``apply.py``. +2. Subclass :class:`BaseRotationConfig` and :class:`BaseRotation`; register + with ``@BaseRotation.register("")``. +3. Re-export from this ``__init__.py``. + +Typical usage +------------- +>>> from auto_round.algorithms.transforms import apply_rotation +>>> model = apply_rotation(model, config={"hadamard_type": "random_hadamard"}) +""" + +from __future__ import annotations + +from typing import Any + +import torch + +from auto_round.algorithms.transforms.base import ( + BaseRotation, + BaseRotationConfig, + ROTATION_SUPPORTED_SCHEMES, + check_supported_schemes, +) +from auto_round.algorithms.transforms.rotation import ( + HadamardRotation, + apply_rotation_transform, + normalize_rotation_config as _normalize_hadamard_config, + RotationConfig, +) + +__all__ = [ + # Base interfaces + "BaseRotation", + "BaseRotationConfig", + "ROTATION_SUPPORTED_SCHEMES", + "check_supported_schemes", + # Config + "RotationConfig", + "HadamardRotation", + "apply_rotation_transform", + # Unified entry + "apply_rotation", + "normalize_rotation_config", +] + + +def normalize_rotation_config( + config: Any, +) -> BaseRotationConfig | None: + """Normalise any supported config form to the canonical :class:`BaseRotationConfig` subclass. + + Dispatches by inspecting the ``algorithm`` field (or missing field for + legacy dicts that only carry Hadamard keys). + + Args: + config: One of: ``None``, :class:`RotationConfig`, a ``dict`` with + an ``"algorithm"`` key, or a plain Hadamard shorthand string. + + Returns: + The appropriate :class:`BaseRotationConfig` subclass, or ``None`` + when *config* is ``None`` / empty. 
+ """ + if config is None: + return None + + if isinstance(config, BaseRotationConfig): + return config + + if isinstance(config, dict): + alg = config.get("algorithm", "hadamard") + if alg == "hadamard": + return RotationConfig.model_validate(config) + raise ValueError( + f"Unknown rotation algorithm: {alg!r}. " f"Registered algorithms: {sorted(BaseRotation._REGISTRY)}" + ) + + if isinstance(config, str): + # String shorthand → treat as Hadamard config. + return RotationConfig.model_validate(_normalize_hadamard_config(config)) + + raise TypeError( + f"Unsupported rotation config type: {type(config).__name__}. " + "Expected None, dict, str, or a BaseRotationConfig subclass." + ) + + +def apply_rotation( + model: torch.nn.Module, + config: Any, + data_type: str = "mx_fp", + **kwargs: Any, +) -> torch.nn.Module: + """Apply a rotation/transform algorithm to *model*. + + This is the single, algorithm-agnostic entry point. The correct + :class:`BaseRotation` subclass is selected automatically from *config*. + + Args: + model: Model to transform (modified in-place). + config: Rotation configuration. Accepts: + + * ``None`` – no-op, returns *model* unmodified. + * :class:`RotationConfig` or compatible ``dict``/``str``. + * Any :class:`BaseRotationConfig` subclass. + + data_type: Quantization data type (e.g. ``"mx_fp"``). + **kwargs: Forwarded to :meth:`BaseRotation.apply_to_model`. + + Returns: + The transformed model. + """ + if config is None: + return model + + normalised = normalize_rotation_config(config) + if normalised is None: + return model + + rotation = BaseRotation.from_config(normalised) + return rotation.apply_to_model(model, data_type=data_type, **kwargs) diff --git a/auto_round/algorithms/transforms/base.py b/auto_round/algorithms/transforms/base.py new file mode 100644 index 000000000..236e64347 --- /dev/null +++ b/auto_round/algorithms/transforms/base.py @@ -0,0 +1,166 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base classes and utilities for weight/activation rotation algorithms. + +All rotation algorithms (Hadamard, SpinQuant, QuaRot, …) must subclass +``BaseRotation`` and declare a corresponding ``BaseRotationConfig``. + +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +import torch + +# --------------------------------------------------------------------------- +# Config base +# --------------------------------------------------------------------------- + + +@dataclass +class BaseRotationConfig: + """Minimal base for all rotation algorithm configs. + + Every concrete config subclass should be a ``dataclass`` so it is + trivially serialisable / comparable. + """ + + #: Human-readable algorithm name, must be unique across all subclasses. 
+ algorithm: str = "base" + + +# --------------------------------------------------------------------------- +# Algorithm base +# --------------------------------------------------------------------------- + + +class BaseRotation(ABC): + """Unified interface for all weight/activation rotation transforms. + + Concrete subclasses implement :meth:`apply_to_model` for their specific + mathematical transform (Hadamard rotation, random rotation, …). + + Example + ------- + >>> from auto_round.algorithms.transforms import apply_rotation + >>> model = apply_rotation(model, config={"algorithm": "hadamard", ...}) + """ + + # Registry populated by subclasses via ``BaseRotation.register``. + _REGISTRY: dict[str, type["BaseRotation"]] = {} + + def __init__(self, config: BaseRotationConfig) -> None: + self.config = config + + # ------------------------------------------------------------------ + # Abstract interface + # ------------------------------------------------------------------ + + @abstractmethod + def apply_to_model( + self, + model: torch.nn.Module, + data_type: str = "mx_fp", + **kwargs: Any, + ) -> torch.nn.Module: + """Apply this rotation to *model* and return the (possibly mutated) model. + + Args: + model: The model to transform. + data_type: Quantization data type (e.g. ``"mx_fp"``). + **kwargs: Algorithm-specific extra arguments. + + Returns: + The transformed model. + """ + + # ------------------------------------------------------------------ + # Factory + # ------------------------------------------------------------------ + + @classmethod + def register(cls, algorithm_name: str): + """Class decorator to register a ``BaseRotation`` subclass. + + Usage:: + + @BaseRotation.register("hadamard") + class HadamardRotation(BaseRotation): + ... + """ + + def _decorator(subclass: type[BaseRotation]) -> type[BaseRotation]: + cls._REGISTRY[algorithm_name] = subclass + return subclass + + return _decorator + + @classmethod + def from_config(cls, config: BaseRotationConfig) -> "BaseRotation": + """Instantiate the correct ``BaseRotation`` subclass for *config*. + + The algorithm is looked up by ``config.algorithm`` in the registry. + Sub-packages are imported lazily on first access so that optional + dependencies (e.g. ``pydantic``) are not required unless actually used. + """ + # Lazy-load all sub-packages to populate the registry. + _ensure_registry_populated() + + name = getattr(config, "algorithm", None) + if name not in cls._REGISTRY: + raise ValueError(f"No rotation algorithm registered under {name!r}. " f"Available: {sorted(cls._REGISTRY)}") + return cls._REGISTRY[name](config) + + +# --------------------------------------------------------------------------- +# Scheme compatibility check +# --------------------------------------------------------------------------- + +#: Quantization schemes that support (and require) rotation transforms. +ROTATION_SUPPORTED_SCHEMES: list[str] = ["MXFP8", "MXFP4", "NVFP4"] + + +def check_supported_schemes(scheme: str) -> None: + """Raise ``ValueError`` if *scheme* does not support rotation transforms.""" + if scheme not in ROTATION_SUPPORTED_SCHEMES: + raise ValueError( + f"Rotation transforms are not supported for scheme {scheme!r}. 
" + f"Currently supported schemes: {ROTATION_SUPPORTED_SCHEMES}" + ) + + +# --------------------------------------------------------------------------- +# Lazy registry population +# --------------------------------------------------------------------------- + +_registry_populated = False + + +def _ensure_registry_populated() -> None: + """Import all known sub-packages so their ``@BaseRotation.register`` calls run.""" + global _registry_populated + if _registry_populated: + return + # Import each sub-package here. Add new entries as more algorithms land. + import importlib + + for sub in ("hadamard",): + try: + importlib.import_module(f"auto_round.algorithms.transforms.{sub}") + except ImportError: + pass + _registry_populated = True diff --git a/auto_round/algorithms/transforms/rotation/__init__.py b/auto_round/algorithms/transforms/rotation/__init__.py new file mode 100644 index 000000000..640cf06ab --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hadamard rotation sub-package for ``algorithms/transforms``.""" + +from auto_round.algorithms.transforms.rotation.apply import ( + HadamardRotation, + apply_rotation_transform, +) +from auto_round.algorithms.transforms.rotation.config import ( + RotationConfig, + normalize_rotation_config, +) +from auto_round.algorithms.transforms.rotation.transforms import ( + HADAMARDS, + HadamardTransform, + RandomHadamardTransform, + build_hadamard_transform, +) + +__all__ = [ + # Algorithm class + "HadamardRotation", + # Config + "RotationConfig", + "normalize_rotation_config", + # Transform modules + "HadamardTransform", + "RandomHadamardTransform", + "HADAMARDS", + "build_hadamard_transform", + # One-shot convenience + "apply_rotation_transform", +] diff --git a/auto_round/algorithms/transforms/rotation/apply.py b/auto_round/algorithms/transforms/rotation/apply.py new file mode 100644 index 000000000..62ceb02b5 --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/apply.py @@ -0,0 +1,309 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hadamard rotation – concrete ``BaseRotation`` implementation. + +Public entry points +------------------- +* :class:`HadamardRotation` – the stateful algorithm object. +* :func:`apply_rotation_transform` – convenience one-shot function. 
+""" + +from __future__ import annotations + +from typing import Any + +import torch +import tqdm + +from auto_round.algorithms.transforms.base import BaseRotation +from auto_round.algorithms.transforms.rotation.config import RotationConfig, normalize_rotation_config +from auto_round.algorithms.transforms.rotation.transforms import build_hadamard_transform +from auto_round.compressors.utils import is_nv_fp +from auto_round.experimental.qmodules.base import QModuleBase + +__all__ = ["HadamardRotation", "apply_rotation_transform"] + + +def _triton_available(data_type: str = "mx_fp") -> bool: + """Best-effort check for whether Triton kernel path can be used.""" + if is_nv_fp(data_type): + return False + try: + import triton # noqa: F401 # pylint: disable=E0401 + + if not torch.cuda.is_available(): + return False + from auto_round.algorithms.transforms.rotation.utils.triton.mxfp4 import ( # noqa: F401 + mxfp4_forward_kernel_wrapper, + ) + + return True + except Exception: + return False + + +@BaseRotation.register("hadamard") +class HadamardRotation(BaseRotation): + """Hadamard rotation algorithm. + + Registered under ``"hadamard"`` in the + :class:`~auto_round.algorithms.transforms.base.BaseRotation` registry. + + Typical usage (via the top-level helper):: + + from auto_round.algorithms.transforms import apply_rotation + model = apply_rotation(model, config={"hadamard_type": "random_hadamard"}) + + Or directly:: + + from auto_round.algorithms.transforms.rotation import apply_rotation_transform + model = apply_rotation_transform(model, config=RotationConfig(), need_calibration=True) + """ + + def __init__(self, config: RotationConfig) -> None: + super().__init__(config) + + @classmethod + def from_config(cls, config: dict | RotationConfig) -> "HadamardRotation": + """Build a :class:`HadamardRotation` from a raw dict or :class:`RotationConfig`.""" + if isinstance(config, dict): + config = RotationConfig.model_validate(config) + return cls(config) + + def apply_to_model( + self, + model: torch.nn.Module, + location: str = "weight", + use_tqdm: bool = True, + desc: str | None = None, + data_type: str = "mx_fp", + **kwargs: Any, + ) -> torch.nn.Module: + """Apply the Hadamard rotation to *model*. + + Args: + model: Target model; modified in-place. + location: ``"weight"`` (eager, fused into weights) or + ``"input"`` (activation-side, via forward hook). + use_tqdm: Show a progress bar while iterating modules. + desc: Custom progress-bar description. + data_type: Quantization data type (e.g. ``"mx_fp"``). + **kwargs: Reserved for future use. + + Returns: + The mutated *model* with ``model.rotation_config`` set to the + normalised :class:`RotationConfig` dict. + """ + cfg = self.config + + # Dispatch by backend. The transform backend (triton-fused per-Linear) + # is implemented below; the inplace (QuaRot) backend is delegated to + # :mod:`auto_round.algorithms.transforms.rotation.inplace`. + from auto_round.algorithms.transforms.rotation.dispatcher import resolve_hadamard_backend + + backend = resolve_hadamard_backend(cfg, data_type) + if backend == "inplace": + import auto_round.envs as envs + from auto_round.algorithms.transforms.rotation.inplace import apply_rotation_transform as _inplace_apply + + # Resolve fuse flag: explicit > env var > default(False). 
+ fuse_online_to_weight = cfg.fuse_online_to_weight + if cfg.fuse_online_to_weight is not None: + fuse_online_to_weight = bool(cfg.fuse_online_to_weight) + elif envs.AR_FUSE_ONLINE_ROTATION: + fuse_online_to_weight = bool(envs.AR_FUSE_ONLINE_ROTATION) + + bs = cfg.block_size + group_size = bs if (bs is not None and bs > 0) else None + + compute_device = kwargs.get("compute_device") + model, _hooks = _inplace_apply( + model, + group_size=group_size, + allow_online_rotation=cfg.allow_online_rotation, + rotation_matrix=cfg.hadamard_type, + fuse_online_to_weight=fuse_online_to_weight, + compute_device=compute_device, + ) + setattr(model, "rotation_config", cfg.model_dump()) + return model + + # backend == "transform": original per-Linear triton-fused path. + # Collect target modules. + target_types = (torch.nn.Linear, QModuleBase) + + modules = [(name, module) for name, module in model.named_modules() if isinstance(module, target_types)] + + _desc = desc or f"Applying {cfg.hadamard_type} transforms" + for name, module in tqdm.tqdm(modules, desc=_desc, disable=not use_tqdm): + if "lm_head" in name: + continue + _apply_to_module(model, module, cfg, location, data_type) + + # Store config on model for serialisation / downstream inspection. + setattr(model, "rotation_config", cfg.model_dump()) + return model + + +# --------------------------------------------------------------------------- +# Module-level application helper +# --------------------------------------------------------------------------- + + +def _apply_to_module( + model: torch.nn.Module, + module: torch.nn.Module, + config: RotationConfig, + location: str, + data_type: str = "mx_fp", +) -> None: + """Apply the configured Hadamard transform to a single *module*.""" + if location == "input": + _apply_input_transform(module, config, data_type) + + elif location == "weight": + _apply_weight_transform(module, config) + + else: + raise NotImplementedError(f"Unsupported transform location: {location!r}") + + +def _apply_input_transform(module: torch.nn.Module, config: RotationConfig, data_type: str = "mx_fp") -> None: + """Register a forward pre-hook that applies the Hadamard to the input activation.""" + from auto_round.algorithms.transforms.rotation.utils.matrix import multihead_matmul + + inp_transform = build_hadamard_transform( + **config.model_dump(), + location="input", + inverse=True, + device="cpu", + precision=module.dtype if hasattr(module, "dtype") else None, + ) + + if config.hadamard_type != "random_hadamard": + hadamard_weight = inp_transform.weight + else: + hadamard_weight = None + + if _triton_available(data_type): + from auto_round.algorithms.transforms.rotation.utils.triton.mxfp4 import mxfp4_forward_kernel_wrapper + + def _input_hook(self, args): + x = args[0] + orig_shape = x.shape + orig_dtype = x.dtype + x_flat = x.contiguous().flatten(end_dim=-2) + w = hadamard_weight.to(orig_dtype) if hadamard_weight is not None else self.hadamard_matrix.T.to(orig_dtype) + qdq_input, _ = mxfp4_forward_kernel_wrapper(x_flat, w) + return qdq_input.reshape(orig_shape).to(orig_dtype) + + module.pre_dequantized_input = True + module.register_forward_pre_hook(_input_hook, prepend=True) + else: + + def _input_hook(self, args): + x = args[0] + ori_shape = x.shape + orig_dtype = x.dtype + if hadamard_weight is not None: + x = x.view(-1, hadamard_weight.shape[0]) + return multihead_matmul(x, hadamard_weight.to(x.device).to(orig_dtype)).view(ori_shape).to(orig_dtype) + else: + x = x.view(-1, self.hadamard_matrix.shape[0]) + return 
multihead_matmul(x, self.hadamard_matrix.T.to(orig_dtype)).view(ori_shape).to(orig_dtype) + + module.pre_dequantized_input = False + module.register_forward_pre_hook(_input_hook, prepend=True) + + +def _apply_weight_transform( + module: torch.nn.Module, + config: RotationConfig, +) -> None: + """Fuse or patch the Hadamard rotation into the weight of *module*.""" + from auto_round.algorithms.transforms.rotation.patch import ( + patch_quantlinear, + patch_wrapperlinear_to_apply_transform, + patch_wrapperwalayer_forward_to_apply_transform, + ) + + assert hasattr(module, "weight"), "Weight transform requires module to have a 'weight' attribute" + + w_transform = build_hadamard_transform( + **config.model_dump(), + location="weight", + device=module.weight.device, + ) + + # For random Hadamard, save the matrix as a submodule for serialisation. + if config.hadamard_type == "random_hadamard": + from auto_round.algorithms.transforms.rotation.patch import patch_quantlinear as _patch_ql + + _patch_ql(w_transform) + + # Patch WrapperLinear and WrapperWALayer so the transform is applied + # during calibration tuning. + inp_transform = build_hadamard_transform( + **config.model_dump(), + location="input", + inverse=True, + device=module.weight.device, + precision=module.weight.dtype, + ) + + patch_wrapperlinear_to_apply_transform(w_transform, inp_transform) + patch_wrapperwalayer_forward_to_apply_transform(inp_transform) + + +# --------------------------------------------------------------------------- +# Convenience one-shot function +# --------------------------------------------------------------------------- + + +def apply_rotation_transform( + model: torch.nn.Module, + config: str | dict | RotationConfig | None, + location: str = "weight", + use_tqdm: bool = True, + desc: str | None = None, + data_type: str = "mx_fp", +) -> torch.nn.Module: + """Apply a Hadamard rotation to *model*. + + This is the main public entry point when you only want Hadamard (rather + than the polymorphic :func:`~auto_round.algorithms.transforms.apply_rotation`). + + Args: + model: Target model. + config: One of: :class:`RotationConfig`, ``dict``, ``str`` + shorthand, or ``None`` (no-op). + location: ``"weight"`` or ``"input"``. + use_tqdm: Show progress bar. + desc: Custom progress-bar label. + data_type: Quantization data type (e.g. ``"mx_fp"``). + + Returns: + The transformed model. + """ + normalised = normalize_rotation_config(config) + if not normalised: + return model + rotation = HadamardRotation.from_config(normalised) + return rotation.apply_to_model( + model, + location=location, + use_tqdm=use_tqdm, + desc=desc, + data_type=data_type, + ) diff --git a/auto_round/algorithms/transforms/rotation/config.py b/auto_round/algorithms/transforms/rotation/config.py new file mode 100644 index 000000000..2a8c50697 --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/config.py @@ -0,0 +1,188 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
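# NOTE (editorial example, not part of this patch): typical use of the one-shot helper
# `apply_rotation_transform` defined above. `model` stands for an already-loaded
# torch.nn.Module; every other name is exported by this package.
from auto_round.algorithms.transforms.rotation import RotationConfig, apply_rotation_transform

cfg = RotationConfig(hadamard_type="random_hadamard", block_size=32)  # 32 matches the mx_fp group size
model = apply_rotation_transform(model, config=cfg, location="weight", data_type="mx_fp")
# config=None is a no-op; a plain dict or a string shorthand such as "random_hadamard"
# is also accepted and goes through normalize_rotation_config() first.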
+"""Rotation/transform configuration (canonical, unified). + +This module is the **single source of truth** for the ``RotationConfig`` +schema. The legacy location +``auto_round.experimental.transform.rotation_config`` re-exports from here. + +Two implementation backends share this one schema (method B): + +* ``backend="inplace"`` – QuaRot-style residual-stream rotation, implemented + under :mod:`auto_round.experimental.rotation_inplace`. Works for any + weight/activation dtype and can optionally fuse the online Hadamard into + weights (``fuse_online_to_weight=True``). + +* ``backend="transform"`` – Per-Linear weight + activation Hadamard with a + fused triton kernel, implemented under + :mod:`auto_round.algorithms.transforms.rotation.apply`. Supports only + MXFP4 / NVFP4 and cannot fuse online to weight. + +* ``backend="auto"`` – dispatcher picks inplace when a fused online rotation + is requested, transform when the data_type is MX/NV-FP, inplace otherwise. +""" + +from __future__ import annotations + +from typing import Any, Optional + +from pydantic import BaseModel, Field, field_validator + +from auto_round.algorithms.transforms.base import BaseRotationConfig +from auto_round.compressors.utils import is_mx_fp, is_nv_fp +from auto_round.utils import logger + +__all__ = [ + "RotationConfig", + "normalize_rotation_config", + "to_dict_rotation_config", + "dump_group_size_to_rotation_config", +] + + +# Supported Hadamard transform types (also used by HadamardTransform registry). +HADAMARD_TYPES: frozenset[str] = frozenset({"hadamard", "random_hadamard", "quarot_hadamard"}) +_SUPPORTED_BACKENDS: frozenset[str] = frozenset({"auto", "inplace", "transform"}) + + +class RotationConfig(BaseModel, BaseRotationConfig): + """Unified configuration for Hadamard rotation/transform applied to a model. + + See the module docstring for a description of the three backends. + + Notes: + * ``block_size`` is the group/block size for grouped Hadamard. + For ``backend="inplace"`` it is forwarded as ``group_size`` + (``None`` / ``-1`` means full-dimension Hadamard). + """ + + # Registry key consumed by BaseRotation.from_config (kept for API parity + # with other BaseRotationConfig subclasses). + algorithm: str = Field(default="hadamard", frozen=True) + + # ---- shared ---- + backend: str = Field(default="auto") + block_size: Optional[int] = Field(default=None) + hadamard_type: str = Field(default="hadamard") + + # ---- inplace-only ---- + fuse_online_to_weight: Optional[bool] = Field(default=None) + allow_online_rotation: bool = Field(default=True) + + # for random hadamard (transform path) + random_seed: bool = Field(default=False, exclude=True) + + model_config = {"arbitrary_types_allowed": True} + + @field_validator("backend") + @classmethod + def _validate_backend(cls, v: str) -> str: + if v not in _SUPPORTED_BACKENDS: + raise ValueError(f"Unsupported backend: {v}. Supported values: {sorted(_SUPPORTED_BACKENDS)}") + return v + + @field_validator("hadamard_type") + @classmethod + def _validate_hadamard_type(cls, v: str) -> str: + if v not in HADAMARD_TYPES: + raise ValueError(f"Unsupported hadamard_type: {v!r}. 
Supported values: {sorted(HADAMARD_TYPES)}") + return v + + +# --------------------------------------------------------------------------- +# Helpers (free functions – match the old experimental/utils.py API) +# --------------------------------------------------------------------------- + + +def to_dict_rotation_config(rotation_config: str | dict | RotationConfig | None) -> dict[str, Any]: + """Convert any supported config form to a plain ``dict`` (no data-type logic). + + Accepts: + * ``None`` → ``{}`` + * :class:`RotationConfig` → ``model_dump()`` + * ``dict`` → shallow-copied + * ``str`` → ``{"hadamard_type": key}`` (``"default"`` ⇒ plain default) + """ + if rotation_config is None: + return {} + if isinstance(rotation_config, str): + key = rotation_config.strip() + if not key: + return {} + if key == "default": + return {"hadamard_type": "hadamard"} + return {"hadamard_type": key} + if isinstance(rotation_config, RotationConfig): + return rotation_config.model_dump() + return dict(rotation_config) + + +def dump_group_size_to_rotation_config(rotation_config: str | dict | RotationConfig, group_size: int) -> dict[str, Any]: + """Return *rotation_config* as a dict with ``block_size`` populated from *group_size* (if unset).""" + rotation_dict = to_dict_rotation_config(rotation_config) + if rotation_dict.get("block_size", None) is None: + rotation_dict["block_size"] = group_size + return rotation_dict + + +def normalize_rotation_config( + rotation_config: str | dict | RotationConfig | None, + data_type: str = "mx_fp", +) -> dict[str, Any]: + """Normalise *rotation_config* to a validated ``dict`` ready for ``RotationConfig(**)``. + + Behaviour: + * ``None`` → ``{}`` + * If ``block_size`` is not set: + - ``mx_fp`` → default 32 + - ``nv_fp`` → default 16 + - other data types → emit a warning (no default) + * If ``block_size`` mismatches the data-type recommendation, emit a warning. + + Raises: + ValueError: If the resulting config is invalid. + """ + + def _apply_data_type_block_size(cfg_dict: dict[str, Any], block_size_explicitly_set: bool) -> dict[str, Any]: + block_size = cfg_dict.get("block_size") + + if not block_size_explicitly_set or block_size is None: + if is_mx_fp(data_type): + cfg_dict["block_size"] = 32 + elif is_nv_fp(data_type): + cfg_dict["block_size"] = 16 + logger.warning("block_size is not set for data_type 'nv_fp'; defaulting to 16.") + else: + logger.warning( + f"block_size is not set and cannot be inferred for data_type {data_type!r}; " + "please set block_size explicitly in rotation_config if needed." 
+ ) + else: + if is_mx_fp(data_type) and block_size != 32: + logger.warning(f"data_type is 'mx_fp' but block_size={block_size}; recommended value is 32.") + elif is_nv_fp(data_type) and block_size != 16: + logger.warning(f"data_type is 'nv_fp' but block_size={block_size}; recommended value is 16.") + + return cfg_dict + + if rotation_config is None: + return {} + + rotation_dict = to_dict_rotation_config(rotation_config) + block_size_explicitly_set = "block_size" in rotation_dict + cfg_dict = _apply_data_type_block_size(rotation_dict, block_size_explicitly_set) + try: + return RotationConfig.model_validate(cfg_dict).model_dump() + except Exception as exc: + raise ValueError(f"Invalid RotationConfig: {exc}") from exc diff --git a/auto_round/algorithms/transforms/rotation/dispatcher.py b/auto_round/algorithms/transforms/rotation/dispatcher.py new file mode 100644 index 000000000..ecae6652f --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/dispatcher.py @@ -0,0 +1,153 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""Unified entry point for Hadamard rotation/transform. + +Two backend implementations exist: + +* ``inplace`` – :mod:`auto_round.algorithms.transforms.rotation.inplace` + QuaRot-style residual-stream rotation. Works for any weight/activation + dtype. Optionally fuses the online Hadamard into weights + (``fuse_online_to_weight=True``). +* ``transform`` – :mod:`auto_round.experimental.transform` + Per-Linear weight + activation Hadamard with a fused triton kernel. + Only supports MXFP4 / NVFP4 and **cannot** fuse online to weight. + +Routing is controlled by :class:`RotationConfig.backend`: + + "inplace" -> always inplace + "transform" -> always transform (validates dtype + no-fuse) + "auto" -> if user asked to fuse -> inplace + elif data_type is mx_fp / nv_fp -> transform + else -> inplace +""" + +from __future__ import annotations + +from typing import Any, Union + +import torch + +import auto_round.envs as envs +from auto_round.algorithms.transforms.rotation.config import RotationConfig, normalize_rotation_config +from auto_round.compressors.utils import is_mx_fp, is_nv_fp +from auto_round.utils import logger + +__all__ = ["apply_hadamard_rotation", "resolve_hadamard_backend"] + + +def _to_config( + rotation_config: Union[str, dict, RotationConfig, None], + data_type: str, +) -> RotationConfig: + """Normalise *rotation_config* and return a :class:`RotationConfig` instance.""" + cfg_dict = normalize_rotation_config(rotation_config, data_type) + if isinstance(cfg_dict, RotationConfig): + return cfg_dict + return RotationConfig.model_validate(cfg_dict or {}) + + +def resolve_hadamard_backend(config: RotationConfig, data_type: str) -> str: + """Resolve the actual backend (``"inplace"`` / ``"transform"``) from config.""" + requested = config.backend + fuse_requested = bool(config.fuse_online_to_weight) + allow_online_rotation: bool = config.allow_online_rotation + + if requested == "inplace": + return "inplace" + + transform_backend_name = "transform" + if requested == "transform": + if fuse_requested: + raise ValueError( + f"backend='{transform_backend_name}' does not support fuse_online_to_weight=True. " + "Use backend='inplace' (or backend='auto' with fuse_online_to_weight=True) instead." 
+ ) + if not (is_mx_fp(data_type) or is_nv_fp(data_type)): + raise ValueError( + f"backend='{transform_backend_name}' only supports MXFP4 / NVFP4 (got data_type={data_type!r}). " + "Use backend='inplace' or backend='auto' for other dtypes." + ) + if not allow_online_rotation: + raise ValueError(f"backend='{transform_backend_name}' only supports `allow_online_rotation`=True") + + return "transform" + + # backend == "auto" + if fuse_requested: + return "inplace" + if is_mx_fp(data_type) or is_nv_fp(data_type): + return "transform" + return "inplace" + + +def apply_hadamard_rotation( + model: torch.nn.Module, + rotation_config: Union[str, dict, RotationConfig, None], + data_type: str, + compute_device: torch.device | str = None, +) -> (torch.nn.Module, Any): + """Apply Hadamard rotation/transform to *model*, dispatching by backend. + + Args: + model: Target model. + rotation_config: ``str`` / ``dict`` / :class:`RotationConfig` / ``None``. + See :class:`RotationConfig` for fields. + data_type: Quantization data type (e.g. ``"mx_fp"``, ``"nv_fp"``, + ``"int"``, ``"fp"``). + compute_device: Device for inplace-backend computation. Ignored by + the transform backend. + + Returns: + The same model (for chaining); also stored on ``model.rotation_config``. + """ + config = _to_config(rotation_config, data_type) + backend = resolve_hadamard_backend(config, data_type) + + # Resolve fuse flag: explicit > env var > default(True) + fuse_online_to_weight = config.fuse_online_to_weight + if config.fuse_online_to_weight is not None: + fuse_online_to_weight = bool(config.fuse_online_to_weight) + elif envs.AR_FUSE_ONLINE_ROTATION: + fuse_online_to_weight = bool(envs.AR_FUSE_ONLINE_ROTATION) + + logger.info( + f"Applying Hadamard (backend={backend}, " + f"data_type={data_type}, fuse_online_to_weight={fuse_online_to_weight if backend == 'inplace' else False})." + ) + + if backend == "inplace": + logger.warning("this backend does not support real exporting, please export the model to fake format") + from auto_round.algorithms.transforms.rotation.inplace import apply_rotation_transform + + # block_size -> group_size (None / -1 / 0 means full-dimension) + bs = config.block_size + group_size = bs if (bs is not None and bs > 0) else None + + model, hooks = apply_rotation_transform( + model, + group_size=group_size, + allow_online_rotation=config.allow_online_rotation, + rotation_matrix=config.hadamard_type, + fuse_online_to_weight=fuse_online_to_weight, + compute_device=compute_device, + ) + # Stash for downstream (export / serialization). Plain dict so JSON + # serialization (HF save_pretrained -> config.json) round-trips. 
+ setattr(model, "rotation_config", config.model_dump() if hasattr(config, "model_dump") else config) + return model, hooks + + elif backend == "transform": + supported_hadamard_types = ("hadamard", "random_hadamard") + if config.hadamard_type not in supported_hadamard_types: + raise ValueError("this backend only supports hadamard or random_hadamard") + from auto_round.algorithms.transforms.rotation.apply import apply_rotation_transform + + return apply_rotation_transform(model, config, data_type=data_type) + else: + raise ValueError(f"Unsupported Hadamard backend {backend!r}") diff --git a/auto_round/algorithms/transforms/rotation/inplace/__init__.py b/auto_round/algorithms/transforms/rotation/inplace/__init__.py new file mode 100644 index 000000000..9bef07da9 --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/inplace/__init__.py @@ -0,0 +1,12 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 +"""Inplace (QuaRot-style) Hadamard rotation backend. + +Canonical home of the residual-stream Hadamard rotation implementation +(formerly under :mod:`auto_round.experimental.rotation_inplace`). +""" + +from auto_round.algorithms.transforms.rotation.inplace.apply import apply_rotation_transform # noqa: F401 +from auto_round.algorithms.transforms.rotation.inplace.hooks import clear_random_hadamard_cache # noqa: F401 + +__all__ = ["apply_rotation_transform", "clear_random_hadamard_cache"] diff --git a/auto_round/algorithms/transforms/rotation/inplace/apply.py b/auto_round/algorithms/transforms/rotation/inplace/apply.py new file mode 100644 index 000000000..3177d6fa9 --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/inplace/apply.py @@ -0,0 +1,882 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +"""Hadamard inplace rotation — public API and rotation primitives. + +Supports LLaMA-2, LLaMA-3, Qwen-3 (and any model with the same layout). +The entry point is :func:`apply_hadamard_rotation`. 
+""" + +import gc +import typing +from typing import Dict, Union + +import torch +import tqdm + +from auto_round.algorithms.transforms.rotation.inplace.hooks import ( + CrossHeadOnlineHadamardHook, + FullOnlineHadamardHook, + GroupOnlineHadamardHook, + _get_custom_had, + _normalize_rotation_matrix, + _resolve_compute_device, + _rotate_embedding_grouped, + _rotate_linear_grouped, + apply_cross_head_had_to_linear, + apply_exact_had_to_linear, + deterministic_hadamard_matrix, + get_hadK, + get_or_create_random_hadamard, +) +from auto_round.algorithms.transforms.rotation.inplace.model_config import ( + MAPPING_REGISTRY, + RotationMapping, + _resolve, + infer_mapping_from_model, +) + +# --------------------------------------------------------------------------- +# Low-level primitives (model-agnostic via RotationMapping) +# --------------------------------------------------------------------------- + + +def _fuse_ln_linear( + layernorm: torch.nn.Module, + linear_layers: typing.Iterable[torch.nn.Linear], +) -> None: + """Fuse the linear operations in LayerNorm into adjacent linear blocks.""" + for linear in linear_layers: + linear_dtype = linear.weight.dtype + dev = linear.weight.device + + W_ = linear.weight.data.double() + ln_weight = layernorm.weight.double().to(dev) + linear.weight.data = (W_ * ln_weight).to(linear_dtype) + + if hasattr(layernorm, "bias") and layernorm.bias is not None: + if linear.bias is None: + linear.bias = torch.nn.Parameter(torch.zeros(linear.out_features, dtype=torch.float64, device=dev)) + ln_bias = layernorm.bias.double().to(dev) + linear.bias.data = linear.bias.data.double() + torch.matmul(W_, ln_bias) + linear.bias.data = linear.bias.data.to(linear_dtype) + + +def _reset_ln_params(layernorm: torch.nn.Module) -> None: + """Reset LayerNorm to identity: weight=1, bias=0.""" + layernorm.weight.data.fill_(1.0) + if hasattr(layernorm, "bias") and layernorm.bias is not None: + layernorm.bias.data.fill_(0.0) + + +def _rotate_linear_by_Q(module: torch.nn.Linear, Q: torch.Tensor, side: str, compute_device=None) -> None: + """Apply rotation *Q* to a Linear layer's weight (and bias if present). + + Args: + side: ``'input'`` → W = W @ Q (rotate input side) + ``'output'`` → W = Q^T @ W (rotate output side) + compute_device: Device to run computation on. If None, auto-detects GPU. + """ + dtype = module.weight.data.dtype + dev = module.weight.data.device + cdev = _resolve_compute_device(compute_device) + W_ = module.weight.data.to(device=cdev, dtype=torch.float64) + Q_ = Q.to(device=cdev) + if side == "input": + new_W = torch.matmul(W_, Q_).to(device=dev, dtype=dtype) + else: + new_W = torch.matmul(Q_.T, W_).to(device=dev, dtype=dtype) + # Release fp64 copy before assigning back so peak memory ≈ 1× weight + 1× rotated. 
+ del W_ + module.weight.data = new_W + if side == "output" and module.bias is not None: + b = module.bias.data.to(device=cdev, dtype=torch.float64) + new_b = torch.matmul(Q_.T, b).to(device=dev, dtype=dtype) + del b + module.bias.data = new_b + del Q_ + + +def _untie_word_embeddings(model, mapping: RotationMapping) -> None: + """Break tied weights between lm_head and embedding if they share the same tensor.""" + embedding = _resolve(model, mapping.embedding) + lm_head = _resolve(model, mapping.lm_head) + + if lm_head.weight.data_ptr() != embedding.weight.data_ptr(): + return + + lm_head.weight = torch.nn.Parameter(lm_head.weight.data.clone()) + if hasattr(model.config, "tie_word_embeddings"): + model.config.tie_word_embeddings = False + + +def _uses_layernorm_with_mean(model, mapping: RotationMapping) -> bool: + """Check whether the model uses standard LayerNorm (which subtracts mean).""" + layers = _resolve(model, mapping.layers_attr) + first_ln = _resolve(layers[0], mapping.attn_input_ln) + return isinstance(first_ln, torch.nn.LayerNorm) + + +def _bake_mean_into_linear(linear: torch.nn.Linear) -> None: + """Subtract column-wise mean from a Linear layer's weight (and mean from bias).""" + linear_dtype = linear.weight.dtype + W_ = linear.weight.data.double() + linear.weight.data = (W_ - W_.mean(dim=-2, keepdim=True)).to(linear_dtype) + if linear.bias is not None: + b_ = linear.bias.data.double() + linear.bias.data = (b_ - b_.mean()).to(linear_dtype) + + +def _subtract_embedding_mean(model, mapping: RotationMapping) -> None: + """Subtract per-row mean from the embedding weight matrix.""" + W = _resolve(model, mapping.embedding) + dtype = W.weight.data.dtype + W_ = W.weight.data.to(dtype=torch.float64) + W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(dtype=dtype) + + if mapping.positional_embedding is not None: + P = _resolve(model, mapping.positional_embedding) + p_dtype = P.weight.data.dtype + P_ = P.weight.data.to(dtype=torch.float64) + P.weight.data = (P_ - P_.mean(dim=-1, keepdim=True)).to(dtype=p_dtype) + + +class _RMSNorm(torch.nn.Module): + """RMS Normalization (no mean subtraction).""" + + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.register_buffer("weight", torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + return x / rms * self.weight + + +def _replace_layernorms_with_rmsnorm(model) -> None: + """Replace all ``nn.LayerNorm`` modules with ``_RMSNorm``.""" + replacements = [] + for name, module in model.named_modules(): + if isinstance(module, torch.nn.LayerNorm): + replacements.append((name, module)) + + for name, module in replacements: + parts = name.rsplit(".", 1) + if len(parts) == 2: + parent = _resolve(model, parts[0]) + attr = parts[1] + else: + parent = model + attr = parts[0] + rms = _RMSNorm(module.normalized_shape[0], eps=module.eps) + rms = rms.to(device=module.weight.device, dtype=module.weight.dtype) + setattr(parent, attr, rms) + + +# --------------------------------------------------------------------------- +# High-level steps driven by RotationMapping +# --------------------------------------------------------------------------- + + +def _fuse_layer_norms(model, mapping: RotationMapping) -> None: + """Fuse all LayerNorm parameters into adjacent Linear layers.""" + layers = _resolve(model, mapping.layers_attr) + + for layer in layers: + mlp_ln = _resolve(layer, mapping.mlp_input_ln) + mlp_linears = [_resolve(layer, 
p) for p in mapping.mlp_in]
+        _fuse_ln_linear(mlp_ln, mlp_linears)
+        _reset_ln_params(mlp_ln)
+
+        attn_ln = _resolve(layer, mapping.attn_input_ln)
+        attn_linears = [
+            _resolve(layer, mapping.attn_q),
+            _resolve(layer, mapping.attn_k),
+            _resolve(layer, mapping.attn_v),
+        ]
+        _fuse_ln_linear(attn_ln, attn_linears)
+        _reset_ln_params(attn_ln)
+
+    pre_head_ln = _resolve(model, mapping.pre_head_ln)
+    lm_head = _resolve(model, mapping.lm_head)
+    _fuse_ln_linear(pre_head_ln, [lm_head])
+    _reset_ln_params(pre_head_ln)
+
+
+# ---------------------------------------------------------------------------
+# Unified weight rotation (full or grouped)
+# ---------------------------------------------------------------------------
+
+
+@torch.inference_mode()
+def _rotate_weights(
+    model,
+    mapping: RotationMapping,
+    use_fast_had: bool = True,
+    group_size: int = None,
+    compute_device: torch.device = None,
+    had_dict: dict = None,
+    preset: str = None,
+    fuse_online_to_weight: bool = True,
+) -> None:
+    """Apply Hadamard rotation to all weights.
+
+    Args:
+        use_fast_had: Use ``fast_hadamard_transform`` for fusable rotations when available.
+        group_size: ``None`` → full Hadamard rotation.
+            ``int`` → block-diagonal rotation with this block size.
+        compute_device: Device to run Hadamard computation on (e.g. ``"cuda:0"``).
+            Weights are moved there temporarily and moved back afterwards.
+            If ``None``, auto-detects GPU availability.
+        fuse_online_to_weight: If ``True`` (default), apply the QuaRot-style fused
+            rotation: residual-stream rotation plus extra input-side Hadamard
+            rotations on ``down_proj`` and the OV pair (``v_proj`` output +
+            ``o_proj`` input) that require compensating online hooks at
+            inference time. If ``False``, skip embedding/lm_head rotation and
+            apply only input-side Hadamards per linear layer, each paired with
+            its own compensating online hook.
+        had_dict: Normalized ``dict[int, Tensor]`` of custom Hadamard matrices
+            (keyed by dimension). Only used in grouped mode.
+        preset: Rotation preset name (``"quarot_hadamard"``, ``"hadamard"``,
+            ``"random_hadamard"``, or ``None``).
+
+            * ``"quarot_hadamard"``: fusable (residual-stream) rotations use
+              ``fast_hadamard_transform`` / random Hadamard; non-fusable
+              (online-paired) rotations and their weight-side counterparts use
+              deterministic ``get_hadK``/``matmul_hadU`` so that the online
+              hook at inference produces the exact same transform.
+            * ``"hadamard"``: all rotations use deterministic ``get_hadK`` /
+              ``matmul_hadU``. Full-mode Q is a deterministic Hadamard matrix.
+            * ``"random_hadamard"``: all rotations use random Hadamard matrices
+              from the global cache (``get_or_create_random_hadamard``).
+              Same dimension → same matrix everywhere.
+            * ``None``: same behaviour as ``"hadamard"`` (built-in butterfly).
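+
+    Note (illustrative): for an orthogonal rotation ``Q`` (``Q @ Q.T == I``),
+    rotating the embedding rows by ``Q`` and each consumer's input-side weight
+    by ``Q`` leaves the network function unchanged, e.g. for one linear layer::
+
+        (h @ Q) @ (W @ Q).T == h @ Q @ Q.T @ W.T == h @ W.T
+
+    Quantization then operates on the rotated weight/activation values.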
+ """ + compute_device = _resolve_compute_device(compute_device) + config = model.config + hidden_size = getattr(config, mapping.hidden_size_attr) + intermediate_size = getattr(config, mapping.intermediate_size_attr) + num_heads = getattr(config, mapping.num_heads_attr) + head_dim = mapping.attn_head_dim or (hidden_size // num_heads) + + is_grouped = group_size is not None and group_size > 0 + desc = f"Rotating (group_size={group_size})" if is_grouped else "Rotating" + + # ----- Resolve per-operation Hadamard sources ----- + fused_fast = use_fast_had + online_fast = False + if preset == "random_hadamard": + fused_fast = False + + # -- Matrix resolution -- + had_matrix, _found = _get_custom_had(had_dict, group_size) if is_grouped else (None, False) + + online_had_matrix = had_matrix + if preset == "random_hadamard" and had_matrix is None: + had_matrix = get_or_create_random_hadamard(group_size if is_grouped else hidden_size, compute_device) + online_had_matrix = had_matrix + if preset == "quarot_hadamard" and is_grouped: + online_had_matrix = None # force deterministic for online-paired + + # -- Helper: look up cached random matrix for online-paired ops -- + def _online_had(dim): + """Return cached random matrix for *dim* under random_hadamard, else None.""" + if preset == "random_hadamard": + return get_or_create_random_hadamard(dim, compute_device) + return None + + if is_grouped: + assert hidden_size % group_size == 0, f"group_size={group_size} must divide hidden_size={hidden_size}" + assert ( + intermediate_size % group_size == 0 + ), f"group_size={group_size} must divide intermediate_size={intermediate_size}" + + # --- Full mode: build Hadamard matrix Q --- + Q = None + if not is_grouped: + if preset == "hadamard": + Q = deterministic_hadamard_matrix(hidden_size, compute_device) + else: + # "random_hadamard", "quarot_hadamard", None — same shape → same matrix + Q = get_or_create_random_hadamard(hidden_size, compute_device) + + # ---- Top-level: embedding / lm_head ---- + # When fuse_online_to_weight=False, skip embedding and lm_head rotation: + # each layer is self-contained (weight rotation + online hook cancel out). 
+ if fuse_online_to_weight: + embedding = _resolve(model, mapping.embedding) + if is_grouped: + _rotate_embedding_grouped( + embedding, group_size, use_fast_had=fused_fast, compute_device=compute_device, had_matrix=had_matrix + ) + else: + dtype = embedding.weight.data.dtype + dev = embedding.weight.data.device + cdev = compute_device + W_ = embedding.weight.data.to(device=cdev, dtype=torch.float64) + new_W = torch.matmul(W_, Q.to(cdev)).to(device=dev, dtype=dtype) + del W_ + embedding.weight.data = new_W + + if mapping.positional_embedding is not None: + pos_emb = _resolve(model, mapping.positional_embedding) + if is_grouped: + _rotate_embedding_grouped( + pos_emb, group_size, use_fast_had=fused_fast, compute_device=compute_device, had_matrix=had_matrix + ) + else: + pos_dtype = pos_emb.weight.data.dtype + pos_dev = pos_emb.weight.data.device + cdev = compute_device + P_ = pos_emb.weight.data.to(device=cdev, dtype=torch.float64) + new_P = torch.matmul(P_, Q.to(cdev)).to(device=pos_dev, dtype=pos_dtype) + del P_ + pos_emb.weight.data = new_P + + # ---- Top-level: lm_head ---- + lm_head = _resolve(model, mapping.lm_head) + if is_grouped: + _rotate_linear_grouped( + lm_head, + group_size, + side="input", + use_fast_had=fused_fast, + compute_device=compute_device, + had_matrix=had_matrix, + ) + else: + _rotate_linear_by_Q(lm_head, Q, side="input", compute_device=compute_device) + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # ---- Per-layer rotation ---- + layers = _resolve(model, mapping.layers_attr) + for layer in tqdm.tqdm(layers, unit="layer", desc=desc): + if fuse_online_to_weight: + # ---- fuse mode: QuaRot-style residual stream rotation ---- + # Q/K/V: only residual Q on input (no online Had stacking, no hook). + # When Q == online Had (e.g. preset="hadamard"), Q @ Q = I cancels + # the rotation entirely, destroying quantization benefit. + # gate/up: only residual Q on input (no online Had stacking, no hook). + # down_proj: residual Q^T on output + online Had on input (+ hook). + # v_proj/o_proj: per-head/cross-head Had below (+ hook on o_proj). 
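+            # Illustrative check for the down_proj weight/hook pairing below
+            # (H orthogonal, so H.T @ H == I):
+            #   online hook:   x -> x @ H.T
+            #   weight fuse:   W -> W @ H.T          (input side)
+            #   layer output:  (x @ H.T) @ (W @ H.T).T == x @ H.T @ H @ W.T == x @ W.T
+            # i.e. the function is preserved while the quantizer sees the
+            # Hadamard-rotated activation.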
+ for attr in (mapping.attn_q, mapping.attn_k, mapping.attn_v): + mod = _resolve(layer, attr) + if is_grouped: + _rotate_linear_grouped( + mod, + group_size, + side="input", + use_fast_had=fused_fast, + compute_device=compute_device, + had_matrix=had_matrix, + ) + else: + _rotate_linear_by_Q(mod, Q, side="input", compute_device=compute_device) + + # o_proj: residual stream output rotation + if is_grouped: + _rotate_linear_grouped( + _resolve(layer, mapping.attn_o), + group_size, + side="output", + use_fast_had=fused_fast, + compute_device=compute_device, + had_matrix=had_matrix, + ) + else: + _rotate_linear_by_Q(_resolve(layer, mapping.attn_o), Q, side="output", compute_device=compute_device) + + # gate/up: only residual Q on input + for attr in mapping.mlp_in: + mod = _resolve(layer, attr) + if is_grouped: + _rotate_linear_grouped( + mod, + group_size, + side="input", + use_fast_had=fused_fast, + compute_device=compute_device, + had_matrix=had_matrix, + ) + else: + _rotate_linear_by_Q(mod, Q, side="input", compute_device=compute_device) + + # down_proj: residual output + online input Had + down_proj = _resolve(layer, mapping.mlp_out) + if is_grouped: + _rotate_linear_grouped( + down_proj, + group_size, + side="output", + use_fast_had=fused_fast, + compute_device=compute_device, + had_matrix=had_matrix, + ) + _rotate_linear_grouped( + down_proj, + group_size, + side="input", + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=online_had_matrix, + ) + else: + _rotate_linear_by_Q(down_proj, Q, side="output", compute_device=compute_device) + apply_exact_had_to_linear( + down_proj, + had_dim=-1, + output=False, + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=_online_had(intermediate_size), + ) + + # OV projection: v_proj per-head output + o_proj full/cross-head input + v_proj = _resolve(layer, mapping.attn_v) + o_proj = _resolve(layer, mapping.attn_o) + if is_grouped: + pass + else: + online_head_had = _online_had(head_dim) + apply_exact_had_to_linear( + v_proj, + had_dim=head_dim, + output=True, + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=online_head_had, + ) + if preset == "random_hadamard": + apply_exact_had_to_linear( + o_proj, + had_dim=head_dim, + output=False, + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=online_head_had, + ) + apply_cross_head_had_to_linear( + o_proj, + num_heads, + head_dim, + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=_online_had(num_heads), + ) + else: + apply_exact_had_to_linear( + o_proj, + had_dim=-1, + output=False, + use_fast_had=online_fast, + compute_device=compute_device, + ) + + else: + # ---- unfused mode: no residual rotation, only input-side Had ---- + # Each layer gets Had fused on input side + compensating hook → equivalent. + # No embedding/lm_head rotation. No self-cancelling pair. + # v_proj treated same as Q/K (input Had only, no per-head/cross-head). 
+ + # Q/K/V: input-side Had on hidden_size + for attr in (mapping.attn_q, mapping.attn_k, mapping.attn_v): + mod = _resolve(layer, attr) + if is_grouped: + _rotate_linear_grouped( + mod, + group_size, + side="input", + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=online_had_matrix, + ) + else: + apply_exact_had_to_linear( + mod, + had_dim=-1, + output=False, + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=_online_had(hidden_size), + ) + + # o_proj: input-side Had on hidden_size (full Had, not cross-head) + o_proj = _resolve(layer, mapping.attn_o) + if is_grouped: + _rotate_linear_grouped( + o_proj, + group_size, + side="input", + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=online_had_matrix, + ) + else: + apply_exact_had_to_linear( + o_proj, + had_dim=-1, + output=False, + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=_online_had(hidden_size), + ) + + # gate/up: input-side Had on hidden_size + for attr in mapping.mlp_in: + mod = _resolve(layer, attr) + if is_grouped: + _rotate_linear_grouped( + mod, + group_size, + side="input", + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=online_had_matrix, + ) + else: + apply_exact_had_to_linear( + mod, + had_dim=-1, + output=False, + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=_online_had(hidden_size), + ) + + # down_proj: input-side Had on intermediate_size + down_proj = _resolve(layer, mapping.mlp_out) + if is_grouped: + _rotate_linear_grouped( + down_proj, + group_size, + side="input", + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=online_had_matrix, + ) + else: + apply_exact_had_to_linear( + down_proj, + had_dim=-1, + output=False, + use_fast_had=online_fast, + compute_device=compute_device, + had_matrix=_online_had(intermediate_size), + ) + + # Per-layer cleanup: drop fp64 temporaries and CUDA caching allocator + # blocks so peak memory stays at ~1 layer's worth instead of accumulating + # across all 32+ decoder layers (was the main cause of 33 GB RAM on 8B). + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Unified online hook registration +# --------------------------------------------------------------------------- + + +def _register_online_hooks( + model, + mapping: RotationMapping, + fp32_had: bool = False, + use_fast_had: bool = True, + group_size: int = None, + had_dict: dict = None, + preset: str = None, + fuse_online_to_weight: bool = True, +): + """Register online Hadamard pre-forward hooks on ``down_proj`` and ``o_proj``. + + Online hooks must use the **same** Hadamard matrix that was applied to the + weight-side counterpart during ``_rotate_weights``. For ``quarot_hadamard`` + this is always the deterministic ``get_hadK``/``matmul_hadU`` path + (``use_fast_had=False``). For ``"random_hadamard"`` it is the random matrix that + was generated once and stored in ``had_dict``. + + Args: + group_size: ``None`` → full Hadamard hooks (original QuaRot). + ``int`` → per-group Hadamard hooks. + had_dict: Normalized ``dict[int, Tensor]`` of custom Hadamard matrices. + preset: Rotation preset name. + Returns: + list of hook handles. 
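+
+    Example (illustrative)::
+
+        handles = _register_online_hooks(model, mapping, group_size=None)
+        # ... run calibration / inference ...
+        for h in handles:
+            h.remove()  # detach all online Hadamard hooks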
+ """ + config = model.config + num_heads = getattr(config, mapping.num_heads_attr) + hidden_size = getattr(config, mapping.hidden_size_attr) + intermediate_size = getattr(config, mapping.intermediate_size_attr) + head_dim = mapping.attn_head_dim or (hidden_size // num_heads) + + is_grouped = group_size is not None and group_size > 0 + + # Online hooks always use deterministic (fixed) Hadamard — never fast_had + # for quarot_hadamard; for "random_hadamard" they use the same random matrix + # that was cached in had_dict by _rotate_weights. + online_fast = False + + # -- Matrix resolution (must match the *online-paired* matrix used by + # _rotate_weights for down_proj input / OV pair). Variable name kept in + # sync with _rotate_weights to make any future drift obvious. + online_had_matrix, _ = _get_custom_had(had_dict, group_size) if is_grouped else (None, False) + if preset == "random_hadamard" and online_had_matrix is None: + online_had_matrix = get_or_create_random_hadamard(group_size if is_grouped else hidden_size) + if preset == "quarot_hadamard" and is_grouped: + online_had_matrix = None + + # -- Helper: look up cached random matrix for online-paired hooks -- + def _online_had(dim): + if preset == "random_hadamard": + return get_or_create_random_hadamard(dim) + return None + + mlp_out_suffix = mapping.mlp_out.split(".")[-1] + attn_o_suffix = mapping.attn_o.split(".")[-1] + + # Suffixes for Q/K/V and gate/up (for online input Had hooks) + attn_qkv_suffixes = set(attr.split(".")[-1] for attr in (mapping.attn_q, mapping.attn_k, mapping.attn_v)) + mlp_in_suffixes = set(attr.split(".")[-1] for attr in mapping.mlp_in) + + # --- Build hook factories --- + def _make_down_proj_hook(): + if is_grouped: + return GroupOnlineHadamardHook( + group_size=group_size, fp32_had=fp32_had, use_fast_had=online_fast, had_matrix=online_had_matrix + ) + online_mat = _online_had(intermediate_size) + if online_mat is not None: + return FullOnlineHadamardHook( + had_K=None, K=None, fp32_had=fp32_had, use_fast_had=online_fast, had_matrix=online_mat + ) + had_K, K = get_hadK(intermediate_size) + return FullOnlineHadamardHook(had_K=had_K, K=K, fp32_had=fp32_had, use_fast_had=online_fast) + + def _make_hidden_had_hook(): + """Full Had hook on hidden_size (for Q/K/V and gate/up input).""" + if is_grouped: + return GroupOnlineHadamardHook( + group_size=group_size, fp32_had=fp32_had, use_fast_had=online_fast, had_matrix=online_had_matrix + ) + online_mat = _online_had(hidden_size) + if online_mat is not None: + return FullOnlineHadamardHook( + had_K=None, K=None, fp32_had=fp32_had, use_fast_had=online_fast, had_matrix=online_mat + ) + had_K, K = get_hadK(hidden_size) + return FullOnlineHadamardHook(had_K=had_K, K=K, fp32_had=fp32_had, use_fast_had=online_fast) + + def _make_o_proj_hook(): + online_mat = _online_had(num_heads) + if online_mat is not None: + return CrossHeadOnlineHadamardHook( + had_K=None, + K=None, + head_dim=head_dim, + fp32_had=fp32_had, + use_fast_had=online_fast, + had_matrix=online_mat, + ) + had_K, K = get_hadK(num_heads) + return CrossHeadOnlineHadamardHook( + had_K=had_K, + K=K, + head_dim=head_dim, + fp32_had=fp32_had, + use_fast_had=online_fast, + ) + + # --- Register --- + handles = [] + + for name, module in model.named_modules(): + if not isinstance(module, torch.nn.Linear): + continue + suffix = name.split(".")[-1] + + if name.endswith(mlp_out_suffix): + # down_proj: full Had on intermediate_size input + h = module.register_forward_pre_hook(_make_down_proj_hook()) + handles.append(h) + 
elif name.endswith(attn_o_suffix):
+            if fuse_online_to_weight and not is_grouped:
+                # o_proj: cross-head Had on input (fused mode, full only)
+                h = module.register_forward_pre_hook(_make_o_proj_hook())
+                handles.append(h)
+            elif not fuse_online_to_weight:
+                # o_proj: full Had on hidden_size input (unfused mode, matches weight rotation)
+                h = module.register_forward_pre_hook(_make_hidden_had_hook())
+                handles.append(h)
+        elif suffix in attn_qkv_suffixes:
+            if not fuse_online_to_weight:
+                # Q/K/V: full Had on hidden_size input (unfused mode only).
+                # In fused mode Q/K/V only have residual Q on weight (no online Had),
+                # and activations come pre-rotated from residual stream → no hook needed.
+                h = module.register_forward_pre_hook(_make_hidden_had_hook())
+                handles.append(h)
+        elif suffix in mlp_in_suffixes:
+            if not fuse_online_to_weight:
+                # gate/up: full Had on hidden_size input (unfused mode only).
+                # Same reasoning as Q/K/V above.
+                h = module.register_forward_pre_hook(_make_hidden_had_hook())
+                handles.append(h)
+
+    return handles
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+
+def apply_rotation_transform(
+    model,
+    group_size: int = None,
+    allow_online_rotation: bool = True,
+    rotation_matrix: Union[str, torch.Tensor, Dict[int, torch.Tensor], None] = None,
+    compute_device: torch.device | str = None,
+    fp32_had: bool = False,
+    fuse_online_to_weight: bool = None,
+):
+    """Fuse layer norms, rotate weights, and register online Hadamard hooks.
+
+    This is the single entry point for applying Hadamard inplace rotation.
+    The model architecture is auto-detected via ``model.config.model_type``.
+
+    Args:
+        model: A HuggingFace CausalLM model (LLaMA-2/3, Qwen-3, etc.).
+        fp32_had: Whether to compute the online Hadamard transform in fp32.
+        group_size: If ``None`` (default), use full-dimension Hadamard rotation.
+            A positive ``int`` selects block-diagonal rotation with that block size.
+        compute_device: Device to run Hadamard computation on.
+        allow_online_rotation: If ``True`` (default), apply online Hadamard
+            rotations on ``down_proj`` input and the OV pair.
+        rotation_matrix: Rotation matrix selection (``"hadamard"``,
+            ``"random_hadamard"``, ``"quarot_hadamard"``, Tensor, dict, or None).
+        fuse_online_to_weight: If ``None`` (default), fusion is enabled
+            automatically for architectures with a registered rotation mapping
+            and disabled otherwise. If ``True``, fuse online Hadamard rotation
+            into weights (down_proj input, v_proj output, o_proj input) and
+            register compensating online hooks. If ``False``, skip
+            embedding/lm_head rotation; each linear layer is self-contained
+            with input-side Had on weight + compensating online hook on
+            activation. No v_proj cross-head or inner-head rotation.
+
+    Returns:
+        ``(model, handles)``: the rotated model and the list of registered
+        online hook handles.
+    """
+    if fuse_online_to_weight is None:
+        if model.config.model_type in MAPPING_REGISTRY or model.__class__.__name__ in MAPPING_REGISTRY:
+            fuse_online_to_weight = True
+        else:
+            fuse_online_to_weight = False
+    had_dict, use_fast_had, preset = _normalize_rotation_matrix(rotation_matrix, group_size)
+    compute_device = _resolve_compute_device(compute_device)
+
+    if use_fast_had:
+        from auto_round.utils import logger
+
+        try:
+            import fast_hadamard_transform  # noqa: F401
+
+            if group_size is None:
+                logger.warning(
+                    "fast_hadamard_transform uses a different Hadamard matrix than the "
+                    "default implementation. Please ensure consistency between training "
+                    "and inference. This will be refined later."
+ ) + except ImportError: + logger.warning("Importing fast_hadamard_transform failed, falling back to default implementation.") + use_fast_had = False + + mapping = infer_mapping_from_model(model) + + _untie_word_embeddings(model, mapping) + + if _uses_layernorm_with_mean(model, mapping): + _subtract_embedding_mean(model, mapping) + + _fuse_layer_norms(model, mapping) + + if _uses_layernorm_with_mean(model, mapping): + layers = _resolve(model, mapping.layers_attr) + for layer in layers: + _bake_mean_into_linear(_resolve(layer, mapping.attn_o)) + _bake_mean_into_linear(_resolve(layer, mapping.mlp_out)) + _replace_layernorms_with_rmsnorm(model) + + _rotate_weights( + model, + mapping, + use_fast_had=use_fast_had, + group_size=group_size, + compute_device=compute_device, + had_dict=had_dict, + preset=preset, + fuse_online_to_weight=fuse_online_to_weight, + ) + + handles = [] + if fuse_online_to_weight or allow_online_rotation: + handles = _register_online_hooks( + model, + mapping, + fp32_had=fp32_had, + use_fast_had=use_fast_had, + group_size=group_size, + had_dict=had_dict, + preset=preset, + fuse_online_to_weight=fuse_online_to_weight, + ) + + return model, handles + + +# --------------------------------------------------------------------------- +# Quick smoke test +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + from transformers import AutoModelForCausalLM, AutoTokenizer + + model_name = "/models/opt-125m" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") + model.to("cuda") + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) + + apply_rotation_transform( + model, group_size=-1, allow_online_rotation=True, rotation_matrix="random_hadamard", fuse_online_to_weight=False + ) + model.to("cuda") + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) + + model_name = "/models/Qwen3-8B" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") + apply_rotation_transform(model, group_size=-1, allow_online_rotation=True, fuse_online_to_weight=True) + model.to("cuda") + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) + + from transformers import AutoModelForCausalLM, AutoTokenizer + + model_name = "/models/Meta-Llama-3.1-8B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") + apply_rotation_transform(model, fuse_online_to_weight=True, group_size=32) + model.to("cuda") + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) + # + # model_name = "/models/Llama-2-7b-chat-hf" + # tokenizer = AutoTokenizer.from_pretrained(model_name) + # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") + # apply_hadamard_rotation(model) + # 
model.to("cuda") + # text = "There is a girl who likes adventure," + # inputs = tokenizer(text, return_tensors="pt").to(model.device) + # print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) diff --git a/auto_round/algorithms/transforms/rotation/inplace/hooks.py b/auto_round/algorithms/transforms/rotation/inplace/hooks.py new file mode 100644 index 000000000..2b3d26c6e --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/inplace/hooks.py @@ -0,0 +1,786 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +"""Online Hadamard transform hooks. + +After weight rotation, down_proj and o_proj require an online Hadamard +transform on their *input activations* at inference time. This module +provides the hooks and a helper to register them on the model. +""" + +import math + +import torch +import torch.nn as nn + +try: + import fast_hadamard_transform +except ImportError: + fast_hadamard_transform = None + + +def _resolve_compute_device(compute_device) -> torch.device: + """Return *compute_device* if explicitly given, otherwise auto-detect GPU. + + When ``compute_device`` is ``None`` the function checks for CUDA / XPU + availability and returns the first accelerator it finds so that heavy + matrix operations are offloaded to GPU even when the model weights live + on CPU. Falls back to ``torch.device("cpu")`` when no accelerator is + present. + """ + if compute_device is not None: + return torch.device(compute_device) if not isinstance(compute_device, torch.device) else compute_device + if torch.cuda.is_available(): + return torch.device("cuda:0") + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return torch.device("xpu:0") + return torch.device("cpu") + + +BUILTIN_ROTATION_PRESETS = {"quarot_hadamard", "hadamard", "random_hadamard"} + +# Global cache for random Hadamard matrices keyed by dimension. +# Ensures the same shape always returns the exact same random matrix within +# a process, across all calls to ``_rotate_weights`` / ``_register_online_hooks``. +_RANDOM_HADAMARD_CACHE: dict = {} + + +def get_or_create_random_hadamard(dim: int, device=None) -> torch.Tensor: + """Return a random Hadamard matrix for *dim*, creating and caching it if needed. + + The matrix is cached globally in ``_RANDOM_HADAMARD_CACHE`` so that every + caller that requests the same *dim* receives the identical matrix. + """ + if dim in _RANDOM_HADAMARD_CACHE: + mat = _RANDOM_HADAMARD_CACHE[dim] + if device is not None: + mat = mat.to(device) + return mat + mat = random_hadamard_matrix(dim, device or torch.device("cpu")) + _RANDOM_HADAMARD_CACHE[dim] = mat + return mat + + +def clear_random_hadamard_cache(): + """Clear the global random Hadamard matrix cache. + + Call this when you want subsequent ``random_hadamard`` preset runs to + generate fresh random matrices (e.g. between independent experiments). + """ + _RANDOM_HADAMARD_CACHE.clear() + + +def _normalize_rotation_matrix(rotation_matrix, group_size): + """Normalize ``rotation_matrix`` into a ``(had_dict, use_fast_had, preset)`` tuple. + + Accepted inputs: + * ``None`` → ``(None, False, None)`` — use built-in butterfly ``matmul_hadU``. + * ``"quarot_hadamard"`` → ``(None, True, "quarot_hadamard")`` — fusable + rotations use ``fast_hadamard_transform`` (random); non-fusable + (online-paired) rotations use deterministic ``get_hadK``/``matmul_hadU``. + * ``"hadamard"`` → ``(None, False, "hadamard")`` — all rotations use + deterministic ``get_hadK``/``matmul_hadU``. 
+ * ``"random_hadamard"`` → ``(None, False, "random_hadamard")`` — all rotations use + ``random_hadamard_matrix``. + * A ``torch.Tensor`` of shape ``(n, n)`` → ``({n: tensor}, False, None)``. + * A ``dict[int, Tensor]`` → ``(dict, False, None)`` — returned as-is. + + Returns: + ``(had_dict, use_fast_had, preset)`` + + Raises: + ValueError: if a non-``str`` *rotation_matrix* is given but + *group_size* is not a positive integer, or an unknown preset. + """ + if rotation_matrix is None: + return None, False, None + + if isinstance(rotation_matrix, str): + if rotation_matrix not in BUILTIN_ROTATION_PRESETS: + raise ValueError( + f"Unknown rotation_matrix preset '{rotation_matrix}'. " + f"Supported presets: {BUILTIN_ROTATION_PRESETS}." + ) + if rotation_matrix == "quarot_hadamard": + return None, True, "quarot_hadamard" + elif rotation_matrix == "hadamard": + return None, False, "hadamard" + else: # "random_hadamard" + return None, False, "random_hadamard" + + is_grouped = group_size is not None and group_size > 0 + if not is_grouped and not isinstance(rotation_matrix, dict): + raise ValueError( + "rotation_matrix (Tensor/dict) can only be used with a positive group_size. " + f"Got group_size={group_size}." + ) + + if isinstance(rotation_matrix, torch.Tensor): + assert ( + rotation_matrix.ndim == 2 and rotation_matrix.shape[0] == rotation_matrix.shape[1] + ), f"rotation_matrix must be square, got shape {rotation_matrix.shape}" + return {rotation_matrix.shape[0]: rotation_matrix}, False, None + + if isinstance(rotation_matrix, dict): + for k, t in rotation_matrix.items(): + assert ( + isinstance(t, torch.Tensor) and t.ndim == 2 and t.shape[0] == t.shape[1] + ), f"rotation_matrix[{k}] must be a square tensor, got shape {t.shape}" + return rotation_matrix, False, None + + raise TypeError( + f"rotation_matrix must be a Tensor, dict[int, Tensor], str, or None. " f"Got {type(rotation_matrix)}." + ) + + +def _get_custom_had(had_dict, size): + """Look up a custom Hadamard matrix for *size* from the normalized dict. + + Returns ``(had_tensor, True)`` if found, ``(None, False)`` otherwise. 
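+
+    Example (illustrative)::
+
+        had_dict = {4: torch.eye(4)}
+        _get_custom_had(had_dict, 4)   # -> (tensor, True)
+        _get_custom_had(had_dict, 8)   # -> (None, False)
+        _get_custom_had(None, 4)       # -> (None, False)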
+ """ + if had_dict is None: + return None, False + if size in had_dict: + return had_dict[size], True + return None, False + + +# --------------------------------------------------------------------------- +# Hook implementations +# --------------------------------------------------------------------------- + + +class FullOnlineHadamardHook(nn.Module): + """Pre-forward hook: full Hadamard on the entire last dimension (for ``down_proj``).""" + + def __init__(self, had_K, K, fp32_had=False, use_fast_had=True, had_matrix=None): + super().__init__() + self.custom_had = had_matrix is not None + if had_matrix is not None: + self.register_buffer("had_matrix", had_matrix) + self.had_K = None + self.K = None + else: + if had_K is not None: + self.register_buffer("had_K", had_K) + else: + self.had_K = None + self.K = K + self.fp32_had = fp32_had + self.use_fast_had = use_fast_had + + def __call__(self, module: nn.Module, args): + x = args[0] if isinstance(args, tuple) else args + x_dtype = x.dtype + + if self.custom_had: + H = self.had_matrix.to(device=x.device, dtype=x.dtype) + if self.fp32_had: + H = self.had_matrix.to(device=x.device).float() + x = (x.float() @ H.T).to(x_dtype) + else: + x = x @ H.T + elif self.fp32_had: + x = matmul_hadU_cuda(x.float(), self.had_K, self.K, use_fast_had=self.use_fast_had).to(x_dtype) + else: + x = matmul_hadU_cuda(x, self.had_K, self.K, use_fast_had=self.use_fast_had) + + if isinstance(args, tuple): + return (x,) + args[1:] + return x + + +class CrossHeadOnlineHadamardHook(nn.Module): + """Pre-forward hook: **cross-head** Hadamard on the ``num_heads`` dimension + (for ``o_proj``). + + After offline rotation: + - ``v_proj`` absorbed a per-head (within-head) Hadamard on ``head_dim``. + - ``o_proj`` absorbed a full Hadamard on ``hidden_size``. + + Since ``H_full = H_cross ⊗ H_within`` (Kronecker decomposition) and the + within-head part is already cancelled by ``v_proj`` through the attention + path (``H_within² = I``), the online hook only needs to apply the residual + **cross-head** Hadamard (``H_cross ⊗ I``): + + * reshape ``(*, hidden_size)`` → ``(*, num_heads, head_dim)`` + * transpose → ``(*, head_dim, num_heads)`` + * Hadamard on the **num_heads** axis (last dim) + * transpose back and reshape + """ + + def __init__(self, had_K, K, head_dim, fp32_had=False, use_fast_had=True, had_matrix=None): + """ + Args: + had_K: Hadamard sub-matrix from ``get_hadK(num_heads)``. + K: Block size from ``get_hadK(num_heads)``. + head_dim: ``hidden_size // num_attention_heads``. + fp32_had: Compute in fp32. + use_fast_had: If True use fast_hadamard_transform; if False use matmul_hadU. + had_matrix: Optional custom rotation matrix of shape ``(num_heads, num_heads)``. 
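+
+        Example (illustrative): for ``hidden_size=4096``, ``num_heads=32`` and
+        ``head_dim=128``, the hook reshapes ``(*, 4096)`` → ``(*, 32, 128)``,
+        transposes to ``(*, 128, 32)``, applies a 32×32 Hadamard on the trailing
+        (head) axis, then transposes and reshapes back.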
+ """ + super().__init__() + self.custom_had = had_matrix is not None + if had_matrix is not None: + self.register_buffer("had_matrix", had_matrix) + self.had_K = None + self.K = None + else: + if had_K is not None: + self.register_buffer("had_K", had_K) + else: + self.had_K = None + self.K = K + self.had_dim = head_dim + self.fp32_had = fp32_had + self.use_fast_had = use_fast_had + + def __call__(self, module: nn.Module, args): + x = args[0] if isinstance(args, tuple) else args + x_dtype = x.dtype + + if self.fp32_had: + x = x.float() + + init_shape = x.shape + num_heads = init_shape[-1] // self.had_dim + + if self.custom_had: + H = self.had_matrix.to(device=x.device, dtype=x.dtype) + # reshape (*, hidden) → (*, num_heads, head_dim), transpose → (*, head_dim, num_heads) + x = x.reshape(-1, num_heads, self.had_dim).transpose(1, 2) + # apply H on last dim (num_heads): x @ H.T + x = (x @ H.T).transpose(1, 2) + elif self.use_fast_had and fast_hadamard_transform is not None and self.K == 1: + x = fast_hadamard_transform.hadamard_transform( + x.reshape(-1, num_heads, self.had_dim).transpose(1, 2), + scale=1 / math.sqrt(num_heads), + ).transpose(1, 2) + else: + # Fallback: use matmul_hadU (pure butterfly + had_K, no fast_hadamard_transform) + x = x.reshape(-1, num_heads, self.had_dim).transpose(1, 2) + x = matmul_hadU(x.contiguous()) + x = x.transpose(1, 2) + + if self.fp32_had: + x = x.to(x_dtype) + x = x.reshape(init_shape) + + if isinstance(args, tuple): + return (x,) + args[1:] + return x + + +# --------------------------------------------------------------------------- +# Registration helper +# --------------------------------------------------------------------------- + + +def register_online_had_hooks(model, mapping=None, fp32_had=False, use_fast_had=True): + """Register online Hadamard pre-forward hooks on ``down_proj`` and ``o_proj``. + + * **down_proj** (``online_full_had``): full Hadamard on ``intermediate_size``. + Compensates ``apply_exact_had_to_linear(down_proj, had_dim=-1, output=False)``. + + * **o_proj** (``online cross-head had``): cross-head Hadamard on ``num_heads``. + Compensates the residual after v_proj's within-head Hadamard cancels. + + Args: + model: A HuggingFace model whose weights have already been rotated. + mapping: A :class:`RotationMapping` (auto-inferred if ``None``). + fp32_had: Whether to compute the Hadamard transform in fp32. + use_fast_had: If True use fast_hadamard_transform; if False use matmul_hadU. + + Returns: + list of hook handles (call ``handle.remove()`` to detach). + """ + if mapping is None: + from auto_round.algorithms.transforms.rotation.inplace.model_config import infer_mapping_from_model + + mapping = infer_mapping_from_model(model) + + config = model.config + num_heads = getattr(config, mapping.num_heads_attr) + hidden_size = getattr(config, mapping.hidden_size_attr) + intermediate_size = getattr(config, mapping.intermediate_size_attr) + head_dim = mapping.attn_head_dim or (hidden_size // num_heads) + + # down_proj: full Hadamard on intermediate_size + had_K_full, K_full = get_hadK(intermediate_size) + + # o_proj: cross-head Hadamard on num_heads + had_K_head, K_head = get_hadK(num_heads) + + # Identify target module suffixes from mapping + mlp_out_suffix = mapping.mlp_out.split(".")[-1] # e.g. "down_proj" + attn_o_suffix = mapping.attn_o.split(".")[-1] # e.g. 
"o_proj" + + handles = [] + for name, module in model.named_modules(): + if name.endswith(mlp_out_suffix) and isinstance(module, nn.Linear): + hook = FullOnlineHadamardHook( + had_K=had_K_full, + K=K_full, + fp32_had=fp32_had, + use_fast_had=use_fast_had, + ) + h = module.register_forward_pre_hook(hook) + handles.append(h) + elif name.endswith(attn_o_suffix) and isinstance(module, nn.Linear): + hook = CrossHeadOnlineHadamardHook( + had_K=had_K_head, + K=K_head, + head_dim=head_dim, + fp32_had=fp32_had, + use_fast_had=use_fast_had, + ) + h = module.register_forward_pre_hook(hook) + handles.append(h) + + return handles + + +def is_pow2(n): + return (n & (n - 1) == 0) and (n > 0) + + +# Adapted from https://github.com/Cornell-RelaxML/quip-sharp/blob/main/lib/utils/matmul_had.py +def get_hadK(n: int, transpose=False) -> (torch.Tensor, int): + hadK, K = None, None + + if is_pow2(n): + K = 1 + return hadK, K + else: + from auto_round.algorithms.transforms.rotation.utils.math import _fetch_hadamard_divisor + + hadK = _fetch_hadamard_divisor(n, torch.float, torch.device("cpu")) + if transpose: + hadK = hadK.T + if hadK is not None: + return hadK, 1 if is_pow2(hadK.shape[0]) else hadK.shape[0] + assert is_pow2(n) + + +def matmul_hadU(X, transpose=False): + n = X.shape[-1] + hadK, K = get_hadK(n, transpose) + input = X.clone().view(-1, n, 1) + output = input.clone() + while input.shape[1] > K: + input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2]) + output = output.view(input.shape) + output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :] + output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :] + output = output.view(input.shape[0], input.shape[1], -1) + input, output = (output, input) + del output + + if K > 1: + # Do not explicitly repeat - OOM + # input = torch.bmm( + # hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input) + # Use bcast instead + input = hadK.view(1, K, K).to(input) @ input + + return input.view(X.shape) / torch.tensor(n).sqrt() + + +def matmul_hadUt(X): + return matmul_hadU(X, transpose=True) + + +def random_hadamard_matrix(size, device): + # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation" + Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) + Q = Q * 2 - 1 + Q = torch.diag(Q) + return matmul_hadU(Q).to(device) + + +def deterministic_hadamard_matrix(size, device): + """Build a deterministic Hadamard matrix of the given *size*. + + Applies the butterfly ``matmul_hadU`` to an identity matrix so that the + result is purely determined by ``get_hadK`` (no random sign flips). 
+ """ + Q = torch.eye(size, dtype=torch.float64) + return matmul_hadU(Q).to(device) + + +def matmul_hadU_cuda(X, hadK, K, use_fast_had=True): + n = X.shape[-1] + if not use_fast_had or fast_hadamard_transform is None: + return matmul_hadU(X) + if K == 1: + return fast_hadamard_transform.hadamard_transform(X.contiguous(), 1.0 / torch.tensor(n).sqrt()) + # if transpose: + # hadK = hadK.T.contiguous() + input = X.view(*X.shape[:-1], K, n // K) + input = fast_hadamard_transform.hadamard_transform(input.contiguous(), 1.0 / torch.tensor(n).sqrt()) + input = hadK.to(input.device).to(input.dtype) @ input + return input.reshape(X.shape) + + +def matmul_hadUt_cuda(X, hadK, K, use_fast_had=True): + return matmul_hadU_cuda(X, hadK, K, use_fast_had=use_fast_had) + + +def apply_exact_had_to_linear( + module, had_dim=-1, output=False, use_fast_had=True, compute_device=None, had_matrix=None +): + """Apply Hadamard rotation to a Linear layer's weight in-place. + + Args: + module: ``nn.Linear`` layer. + had_dim: Dimension of each Hadamard block (``-1`` for full dimension). + output: If ``True`` rotate the output (row) side; otherwise input (col). + use_fast_had: Use ``fast_hadamard_transform`` when available. + compute_device: Device to run computation on. + had_matrix: Optional custom rotation matrix. When ``had_dim == -1`` + this should be a square tensor whose size equals + ``out_features`` (output) or ``in_features`` (input). When + ``had_dim > 0`` the size should equal ``had_dim``. + """ + assert isinstance(module, torch.nn.Linear) + in_features, out_features = module.in_features, module.out_features + + if had_dim != -1 and had_matrix is None: + assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!" + + W_ = module.weight.data + dtype = W_.dtype + dev = W_.device + init_shape = W_.shape + compute_dev = _resolve_compute_device(compute_device) + W_ = W_.double().to(compute_dev) + + if had_matrix is not None: + H = had_matrix.to(device=compute_dev, dtype=torch.float64) + if had_dim == -1: + # Full-dimension custom matrix + if output: + # W.T = H @ W.T → W = (H @ W.T).T = W @ H.T + W_ = W_ @ H.T + else: + # W = H @ W (rotate input columns: W_new[i,:] = sum H[i,j]*W[j,:]) + # Actually for input side: W_new = W @ H (each row is rotated) + W_ = W_ @ H.T + else: + # Per-block custom matrix + if output: + W_ = W_.t() + transposed_shape = W_.shape + flat = W_.reshape(-1, had_dim) + W_ = (flat @ H.T).reshape(transposed_shape).t() + else: + flat = W_.reshape(-1, had_dim) + W_ = (flat @ H.T).reshape(init_shape) + elif had_dim == -1: + if output: + had_K, K = get_hadK(out_features) + W_ = matmul_hadU_cuda(W_.t(), had_K, K, use_fast_had=use_fast_had).t() + if not output: + had_K, K = get_hadK(in_features) + W_ = matmul_hadU_cuda(W_, had_K, K, use_fast_had=use_fast_had) + else: + # Apply Hadamard to the last had_dim chunks of the weights + if output: + W_ = W_.t() + transposed_shape = W_.shape + if use_fast_had and fast_hadamard_transform is not None: + W_ = ( + fast_hadamard_transform.hadamard_transform( + W_.reshape(-1, transposed_shape[-1] // had_dim, had_dim), scale=1 / math.sqrt(had_dim) + ) + .reshape(transposed_shape) + .t() + ) + else: + W_ = matmul_hadU(W_.reshape(-1, had_dim)).reshape(transposed_shape).t() + else: + if use_fast_had and fast_hadamard_transform is not None: + n = W_.shape[1] + W_ = fast_hadamard_transform.hadamard_transform( + W_.reshape(-1, n // had_dim, had_dim), scale=1 / math.sqrt(had_dim) + ).reshape(init_shape) + else: + W_ = matmul_hadU(W_.reshape(-1, 
had_dim)).reshape(init_shape) + module.weight.data = W_.to(device=dev, dtype=dtype) + + +def apply_cross_head_had_to_linear( + module, num_heads, head_dim, use_fast_had=True, compute_device=None, had_matrix=None +): + """Apply a cross-head Hadamard rotation to a Linear layer's input side. + + The operation is equivalent to ``(H_cross ⊗ I_head_dim)`` applied to the + input columns: + + * Reshape columns ``(hidden_size,)`` → ``(num_heads, head_dim)`` + * Transpose → ``(head_dim, num_heads)`` + * Hadamard on the ``num_heads`` axis + * Transpose back and reshape + + This mirrors what :class:`CrossHeadOnlineHadamardHook` does at runtime. + + Args: + module: ``nn.Linear`` layer whose ``in_features == num_heads * head_dim``. + num_heads: Number of attention heads. + head_dim: Per-head dimension. + use_fast_had: Use ``fast_hadamard_transform`` when available. + compute_device: Device to run computation on. + had_matrix: Optional custom rotation matrix of shape ``(num_heads, num_heads)``. + """ + assert isinstance(module, torch.nn.Linear) + W_ = module.weight.data + dtype = W_.dtype + dev = W_.device + compute_dev = _resolve_compute_device(compute_device) + W_ = W_.double().to(compute_dev) + + out_f = W_.shape[0] + # W shape: (out_features, hidden_size) where hidden_size = num_heads * head_dim + # Reshape columns: (out_f, num_heads, head_dim) + W_ = W_.reshape(out_f, num_heads, head_dim) + # Transpose last two dims: (out_f, head_dim, num_heads) + W_ = W_.transpose(1, 2).contiguous() + + if had_matrix is not None: + H = had_matrix.to(device=compute_dev, dtype=torch.float64) + # Apply H on last dim (num_heads): flat @ H.T + flat = W_.reshape(-1, num_heads) + W_ = (flat @ H.T).reshape(out_f, head_dim, num_heads) + elif use_fast_had and fast_hadamard_transform is not None and is_pow2(num_heads): + W_ = fast_hadamard_transform.hadamard_transform(W_, scale=1.0 / math.sqrt(num_heads)) + else: + W_ = matmul_hadU(W_.reshape(-1, num_heads)).reshape(out_f, head_dim, num_heads) + + # Transpose back: (out_f, num_heads, head_dim) → (out_f, hidden_size) + W_ = W_.transpose(1, 2).contiguous().reshape(out_f, num_heads * head_dim) + module.weight.data = W_.to(device=dev, dtype=dtype) + + +# --------------------------------------------------------------------------- +# Grouped (block-diagonal) Hadamard utilities +# --------------------------------------------------------------------------- + + +class OnlineHadamardPostHook(nn.Module): + """Forward hook (post-hook) adapter: wraps a pre-hook-style Hadamard + transform to apply it on the layer's **output** instead of input. + + Used for v_proj per-head Hadamard on the output side when online + rotation is not fused into weights. + """ + + def __init__(self, pre_hook): + super().__init__() + self.pre_hook = pre_hook + + def __call__(self, module, input, output): + result = self.pre_hook(module, (output,)) + if isinstance(result, tuple): + return result[0] + return result + + +class GroupOnlineHadamardHook(nn.Module): + """Pre-forward hook: block-diagonal Hadamard with fixed ``group_size`` on last dim. + + Reshapes ``(*, D)`` → ``(*, D // group_size, group_size)``, applies Hadamard + per group, then reshapes back. Much cheaper than a full-dimension Hadamard. 
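+
+    Example (illustrative)::
+
+        hook = GroupOnlineHadamardHook(group_size=128)
+        handle = linear.register_forward_pre_hook(hook)   # ``linear``: an nn.Linear
+        # later: handle.remove()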
+ """ + + def __init__(self, group_size, fp32_had=False, use_fast_had=True, had_matrix=None): + super().__init__() + self.group_size = group_size + self.fp32_had = fp32_had + self.use_fast_had = use_fast_had + self.custom_had = had_matrix is not None + + if had_matrix is not None: + self.register_buffer("had_matrix", had_matrix) + self.had_K = None + self.K = None + elif not is_pow2(group_size): + had_K, K = get_hadK(group_size) + if had_K is not None: + self.register_buffer("had_K", had_K) + else: + self.had_K = None + self.K = K + else: + self.had_K = None + self.K = 1 + + def __call__(self, module: nn.Module, args): + x = args[0] if isinstance(args, tuple) else args + x_dtype = x.dtype + init_shape = x.shape + gs = self.group_size + + if self.fp32_had: + x = x.float() + + # Reshape: (*, D) → (*, D//gs, gs) + x = x.reshape(*init_shape[:-1], init_shape[-1] // gs, gs) + + if self.custom_had: + H = self.had_matrix.to(device=x.device, dtype=x.dtype) + flat = x.reshape(-1, gs) + x = (flat @ H.T).reshape(*init_shape[:-1], init_shape[-1] // gs, gs) + elif self.use_fast_had and fast_hadamard_transform is not None and self.K == 1: + x = fast_hadamard_transform.hadamard_transform(x, scale=1.0 / math.sqrt(gs)) + else: + x = x.reshape(-1, gs) + x = matmul_hadU(x) + x = x.reshape(*init_shape[:-1], init_shape[-1] // gs, gs) + + x = x.reshape(init_shape) + + if self.fp32_had: + x = x.to(x_dtype) + + if isinstance(args, tuple): + return (x,) + args[1:] + return x + + +def _apply_grouped_had_to_weight(W, group_size, side="input", use_fast_had=True, had_matrix=None): + """Apply block-diagonal Hadamard to a weight matrix. + + Args: + W: Weight tensor, shape (out_features, in_features). + group_size: Block size for the Hadamard rotation. + side: ``'input'`` rotates columns (in_features dim), + ``'output'`` rotates rows (out_features dim). + use_fast_had: Use fast_hadamard_transform if available. + had_matrix: Optional custom Hadamard matrix of shape ``(gs, gs)`` + to use instead of the built-in Hadamard. + + Returns: + Rotated weight tensor. + """ + gs = group_size + dtype = W.dtype + W = W.double() + + def _had_on_last_dim(X): + """Apply Hadamard on the last dimension (size gs) of X shaped (..., gs).""" + if had_matrix is not None: + H = had_matrix.to(device=X.device, dtype=X.dtype) + # X: (..., gs) → batch matmul with H^T → X @ H^T + flat = X.reshape(-1, gs) + return (flat @ H.T).reshape(X.shape) + if use_fast_had and fast_hadamard_transform is not None and is_pow2(gs): + return fast_hadamard_transform.hadamard_transform(X, scale=1.0 / math.sqrt(gs)) + orig_shape = X.shape + return matmul_hadU(X.reshape(-1, gs)).reshape(orig_shape) + + if side == "input": + out_f, in_f = W.shape + W = W.reshape(out_f, in_f // gs, gs) + W = _had_on_last_dim(W) + W = W.reshape(out_f, in_f) + else: + out_f, in_f = W.shape + Wt = W.t().contiguous() + Wt = Wt.reshape(in_f, out_f // gs, gs) + Wt = _had_on_last_dim(Wt) + W = Wt.reshape(in_f, out_f).t().contiguous() + + return W.to(dtype) + + +def _rotate_linear_grouped(module, group_size, side="input", use_fast_had=True, compute_device=None, had_matrix=None): + """Apply block-diagonal Hadamard rotation to a Linear layer's weight. + + Args: + module: ``nn.Linear`` layer. + group_size: Block size. + side: ``'input'`` or ``'output'``. + use_fast_had: Use fast_hadamard_transform. + compute_device: Device to run computation on. If None, auto-detects GPU. + had_matrix: Optional custom Hadamard matrix of shape ``(gs, gs)``. 
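+
+    Example (illustrative): pairing this input-side rotation with a
+    :class:`GroupOnlineHadamardHook` on the same layer (same ``group_size`` and
+    ``had_matrix``) leaves the layer's output unchanged up to rounding::
+
+        _rotate_linear_grouped(linear, group_size=128, side="input")
+        linear.register_forward_pre_hook(GroupOnlineHadamardHook(group_size=128))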
+ """ + dtype = module.weight.data.dtype + dev = module.weight.data.device + compute_dev = _resolve_compute_device(compute_device) + W = module.weight.data.to(device=compute_dev, dtype=torch.float64) + W = _apply_grouped_had_to_weight(W, group_size, side=side, use_fast_had=use_fast_had, had_matrix=had_matrix) + module.weight.data = W.to(device=dev, dtype=dtype) + + if side == "output" and module.bias is not None: + bias = module.bias.data.to(device=compute_dev, dtype=torch.float64) + gs = group_size + bias = bias.reshape(-1, gs) + if had_matrix is not None: + H = had_matrix.to(device=compute_dev, dtype=torch.float64) + bias = (bias @ H.T).reshape(-1) + elif use_fast_had and fast_hadamard_transform is not None and is_pow2(gs): + bias = ( + fast_hadamard_transform.hadamard_transform(bias.unsqueeze(0), scale=1.0 / math.sqrt(gs)) + .squeeze(0) + .reshape(-1) + ) + else: + bias = matmul_hadU(bias).reshape(-1) + module.bias.data = bias.to(device=dev, dtype=dtype) + + +def _rotate_embedding_grouped(embedding, group_size, use_fast_had=True, compute_device=None, had_matrix=None): + """Apply block-diagonal Hadamard rotation to an Embedding layer. + + Embedding weight: (vocab, hidden_size) → rotate on hidden_size (columns). + """ + dtype = embedding.weight.data.dtype + dev = embedding.weight.data.device + compute_dev = _resolve_compute_device(compute_device) + W = embedding.weight.data.to(device=compute_dev, dtype=torch.float64) + W = _apply_grouped_had_to_weight(W, group_size, side="input", use_fast_had=use_fast_had, had_matrix=had_matrix) + new_W = W.to(device=dev, dtype=dtype) + del W + embedding.weight.data = new_W + + +def register_online_had_hooks_grouped(model, mapping, group_size, fp32_had=False, use_fast_had=True): + """Register per-group online Hadamard hooks on ``down_proj`` and ``o_proj``. + + In grouped mode: + - **down_proj**: block-diagonal Hadamard on ``intermediate_size`` with ``group_size``. + - **o_proj**: block-diagonal Hadamard on ``hidden_size`` with ``group_size``. + + Args: + model: HuggingFace model with rotated weights. + mapping: RotationMapping. + group_size: Block size for block-diagonal Hadamard. + fp32_had: Compute in fp32. + use_fast_had: Use fast_hadamard_transform. + + Returns: + list of hook handles. + """ + mlp_out_suffix = mapping.mlp_out.split(".")[-1] + attn_o_suffix = mapping.attn_o.split(".")[-1] + + handles = [] + for name, module in model.named_modules(): + if name.endswith(mlp_out_suffix) and isinstance(module, nn.Linear): + hook = GroupOnlineHadamardHook( + group_size=group_size, + fp32_had=fp32_had, + use_fast_had=use_fast_had, + ) + h = module.register_forward_pre_hook(hook) + handles.append(h) + elif name.endswith(attn_o_suffix) and isinstance(module, nn.Linear): + hook = GroupOnlineHadamardHook( + group_size=group_size, + fp32_had=fp32_had, + use_fast_had=use_fast_had, + ) + h = module.register_forward_pre_hook(hook) + handles.append(h) + + return handles diff --git a/auto_round/algorithms/transforms/rotation/inplace/model_config.py b/auto_round/algorithms/transforms/rotation/inplace/model_config.py new file mode 100644 index 000000000..3ecbf9b69 --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/inplace/model_config.py @@ -0,0 +1,169 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +"""Model architecture mapping for Hadamard rotation. + +Each :class:`RotationMapping` describes *where* the rotation-relevant modules +live inside a model. Currently supports LLaMA-2, LLaMA-3, and Qwen-3 (dense). 
+ +New architectures can be supported by calling :func:`register_mapping`. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +from auto_round.utils import logger + +__all__ = [ + "RotationMapping", + "register_mapping", + "get_mapping", + "infer_mapping_from_model", + "MAPPING_REGISTRY", +] + + +# --------------------------------------------------------------------------- +# Mapping dataclass +# --------------------------------------------------------------------------- + + +@dataclass +class RotationMapping: + """Declarative description of a transformer architecture for Hadamard rotation. + + Attribute names follow the dot-path convention relative to the model or + each decoder layer. + + Config attribute names (read from ``model.config``): + num_heads_attr, hidden_size_attr, intermediate_size_attr + head_dim_override – explicit head dim (skip hidden_size // num_heads) + """ + + # -- top-level modules (dot-path from model root) -- + embedding: str = "model.embed_tokens" + lm_head: str = "lm_head" + positional_embedding: Optional[str] = None # e.g. "model.decoder.embed_positions" for OPT + + # -- layers container (dot-path from model root) -- + layers_attr: str = "model.layers" + + # -- per-layer: attention (dot-path from each layer) -- + attn_input_ln: str = "input_layernorm" + attn_q: str = "self_attn.q_proj" + attn_k: str = "self_attn.k_proj" + attn_v: str = "self_attn.v_proj" + attn_o: str = "self_attn.o_proj" + + # -- per-layer: MLP (dot-path from each layer) -- + mlp_input_ln: str = "post_attention_layernorm" + mlp_in: List[str] = field(default_factory=lambda: ["mlp.up_proj", "mlp.gate_proj"]) + mlp_out: str = "mlp.down_proj" + + # -- final norm (dot-path from model root) -- + pre_head_ln: str = "model.norm" + + # -- head dim override (None = hidden_size // num_heads) -- + attn_head_dim: Optional[int] = None + + # -- config attr names -- + num_heads_attr: str = "num_attention_heads" + hidden_size_attr: str = "hidden_size" + intermediate_size_attr: str = "intermediate_size" + + +# --------------------------------------------------------------------------- +# Helper: resolve a dot-path attribute on a module +# --------------------------------------------------------------------------- + + +def _resolve(root, dot_path: str): + """Resolve ``'a.b.c'`` to ``root.a.b.c``.""" + obj = root + for attr in dot_path.split("."): + obj = getattr(obj, attr) + return obj + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + +MAPPING_REGISTRY: Dict[str, RotationMapping] = {} + + +def register_mapping(key: str, mapping: RotationMapping) -> RotationMapping: + """Register a :class:`RotationMapping` under *key* (model_type or architecture).""" + MAPPING_REGISTRY[key] = mapping + return mapping + + +def get_mapping(key: str) -> RotationMapping: + """Look up a mapping by *key*; fall back to default if not found.""" + if key in MAPPING_REGISTRY: + return MAPPING_REGISTRY[key] + logger.warning(f"No rotation mapping registered for '{key}', " "falling back to default (LLaMA-like) mapping.") + return RotationMapping() + + +def infer_mapping_from_model(model) -> RotationMapping: + """Return the best :class:`RotationMapping` for *model*. + + Tries ``model.config.model_type`` first, then ``model.__class__.__name__``. 
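+
+    Example (illustrative)::
+
+        mapping = infer_mapping_from_model(model)        # e.g. a LlamaForCausalLM
+        layers = _resolve(model, mapping.layers_attr)    # model.model.layers
+        q_proj = _resolve(layers[0], mapping.attn_q)     # first layer's q_proj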
+ """ + model_type = getattr(getattr(model, "config", None), "model_type", "") + if model_type in MAPPING_REGISTRY: + return MAPPING_REGISTRY[model_type] + + arch = model.__class__.__name__ + if arch in MAPPING_REGISTRY: + return MAPPING_REGISTRY[arch] + + logger.warning( + f"Unrecognised architecture '{arch}' (model_type='{model_type}'). " + "Falling back to default (LLaMA-like) mapping." + ) + return RotationMapping() + + +# =================================================================== +# Built-in mappings +# =================================================================== + +# LLaMA-2 / LLaMA-3 / Mistral / Yi — all share the same layout +_default = RotationMapping() + +register_mapping("llama", _default) +register_mapping("LlamaForCausalLM", _default) + +# Qwen-3 dense — identical layout to LLaMA +register_mapping("qwen3", _default) +register_mapping("Qwen3ForCausalLM", _default) + +# Qwen-2 / Qwen-2.5 dense — identical layout to LLaMA +register_mapping("qwen2", _default) +register_mapping("Qwen2ForCausalLM", _default) + +# ---- OPT ---- +# OPT uses standard LayerNorm (with bias, subtracts mean), +# different module names, and tied lm_head ↔ embedding weights. +_opt = RotationMapping( + embedding="model.decoder.embed_tokens", + lm_head="lm_head", + positional_embedding="model.decoder.embed_positions", + layers_attr="model.decoder.layers", + attn_input_ln="self_attn_layer_norm", + attn_q="self_attn.q_proj", + attn_k="self_attn.k_proj", + attn_v="self_attn.v_proj", + attn_o="self_attn.out_proj", + mlp_input_ln="final_layer_norm", + mlp_in=["fc1"], + mlp_out="fc2", + pre_head_ln="model.decoder.final_layer_norm", + intermediate_size_attr="ffn_dim", +) +register_mapping("opt", _opt) +register_mapping("OPTForCausalLM", _opt) diff --git a/auto_round/algorithms/transforms/rotation/patch.py b/auto_round/algorithms/transforms/rotation/patch.py new file mode 100644 index 000000000..632ec9682 --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/patch.py @@ -0,0 +1,197 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Monkey-patching helpers to inject Hadamard transforms into calibration wrappers. + +During AutoRound calibration (``need_calibration=True``) the weight is re- +quantised at every forward pass. These patches insert the Hadamard rotation +into :class:`~auto_round.wrapper.WrapperLinear` and +:class:`~auto_round.wrapper.WrapperWALayer` so the transform is applied +transparently inside the tuning loop. + +Each patch is idempotent: calling it twice has no effect. 
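+
+Example (illustrative; ``w_transform`` / ``inp_transform`` are Hadamard transform
+modules built elsewhere, e.g. by the ``transforms`` module factory)::
+
+    patch_wrapperlinear_to_apply_transform(w_transform, inp_transform)
+    patch_wrapperwalayer_forward_to_apply_transform(inp_transform)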
+""" + +from __future__ import annotations + +import torch +import transformers + +from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear, pack_fp4_to_uint8 +from auto_round.wrapper import WrapperLinear, WrapperWALayer + +__all__ = [ + "patch_wrapperlinear_to_apply_transform", + "patch_wrapperwalayer_forward_to_apply_transform", + "patch_quantlinear", +] + + +def patch_wrapperlinear_to_apply_transform( + w_transform: torch.nn.Module, + inp_transform: torch.nn.Module, +) -> None: + """Inject *w_transform* and *inp_transform* into :class:`WrapperLinear`. + + After this call, every ``WrapperLinear`` instance will: + + * Apply *w_transform* to the weight before quantisation (``_qdq_weight``). + * Apply *inp_transform* to the activation before quantisation (``_qdq_act``). + + The patch is written at the **class** level and is therefore global – it + affects all future instances as well. A guard flag ``_hadamard_patched`` + prevents double-patching. + """ + if getattr(WrapperLinear, "_hadamard_patched", False): + return + + _orig_qdq_weight = WrapperLinear._qdq_weight + + def _qdq_weight_patched(self, value, min_scale, max_scale): + if self.orig_layer.bits >= 16: + # Keep original behaviour for >=16-bit quantisation. + return _orig_qdq_weight(self, value, min_scale, max_scale) + + if getattr(self, "applied_weight_hadamard", None) is None: + with torch.no_grad(): + weight = self.orig_layer.weight + if weight.device.type == "meta": + weight = self.orig_layer.get_weight().to(self.device) + + is_conv1d = type(self.orig_layer) is transformers.pytorch_utils.Conv1D + if is_conv1d: + weight = weight.t().contiguous() + new_weight = w_transform(weight).to(self.device) + if is_conv1d: + new_weight = new_weight.t().contiguous() + self.orig_layer.weight.data.copy_(new_weight) + self.applied_weight_hadamard = True + + return _orig_qdq_weight(self, value, min_scale, max_scale) + + _orig_qdq_act = WrapperLinear._qdq_act + + def _qdq_act_patched(self, x, act_min_scale=torch.tensor(1.0), act_max_scale=torch.tensor(1.0), act_max=None): + x = inp_transform(x) + + return _orig_qdq_act(self, x, act_min_scale=act_min_scale, act_max_scale=act_max_scale, act_max=act_max) + + WrapperLinear._qdq_weight = _qdq_weight_patched + WrapperLinear._qdq_act = _qdq_act_patched + WrapperLinear._hadamard_patched = True + + +def patch_wrapperwalayer_forward_to_apply_transform( + inp_transform: torch.nn.Module, +) -> None: + """Inject *inp_transform* into :class:`WrapperWALayer`.forward. + + After this call every ``WrapperWALayer`` will rotate its input activation + before the activation quantisation step. Idempotent via + ``_hadamard_forward_patched`` guard. + """ + if getattr(WrapperWALayer, "_hadamard_forward_patched", False): + return + + _orig_forward = WrapperWALayer.forward + + def _forward_patched(self, x): + act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None + x = inp_transform(x) + x, _, _ = self.orig_layer.act_quant_func( + x, + bits=self.orig_layer.act_bits, + group_size=self.orig_layer.act_group_size, + scale_dtype=self.orig_layer.scale_dtype, + q_scale_thresh=self.orig_layer.q_scale_thresh, + data_type=self.orig_layer.act_data_type, + tensor_max=act_max, + ) + return self.orig_layer.forward(x) + + WrapperWALayer.forward = _forward_patched + WrapperWALayer._hadamard_forward_patched = True + + +def patch_quantlinear(w_transform) -> None: + """Patch :class:`QuantLinear` so random Hadamard matrices are saved when packing. 
+ + Only needed for ``random_hadamard`` where the rotation matrix must be + serialised alongside the quantised weights for correct inference. + Idempotent via ``_pack_patched`` guard. + """ + if getattr(QuantLinear, "_pack_patched", False): + return + + from auto_round.data_type.nvfp import cast_to_fp4, get_reciprocal + from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad + from auto_round.utils import get_packing_device + + E8M0_EXPONENT_BIAS = 127 + E8M0_EXPONENT_NAN_VAL = 255 + + def _pack_patched( + self, + linear, + scales, + zeros=None, + g_idx=None, + global_scale=None, + input_global_scale=None, + device=None, + ): + device = get_packing_device(device) + if getattr(linear, "bias", None) is not None: + self.bias = linear.bias.detach().to(torch.float16) + + W = linear.weight.data.detach().to(device) + if type(linear) is torch.nn.Conv2d: + W = W.flatten(1) + if type(linear) is transformers.pytorch_utils.Conv1D: + W = W.t() + + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(W, self.group_size) + scales = scales.to(device) + if self.is_nv: + assert global_scale is not None and global_scale.numel() == 1 + global_scale = global_scale.reshape([1]).to(device) + scaled_tensor = tensor.to(global_scale.dtype) * get_reciprocal( + scales.reshape(tensor.shape[0], -1) * get_reciprocal(global_scale) + ) + scaled_tensor.clamp_(-6.0, 6.0) + scaled_tensor = cast_to_fp4(scaled_tensor) + else: + scaled_tensor = tensor / (2 ** scales.reshape(tensor.shape[0], -1)) + scaled_tensor = revert_tensor_by_pad(scaled_tensor, orig_shape=orig_shape, pad_len=pad_len) + if self.is_mx: + final_scale = (scales + E8M0_EXPONENT_BIAS).clamp(0, E8M0_EXPONENT_NAN_VAL).to(torch.uint8) + else: + final_scale = scales.to(torch.float8_e4m3fn) + + self.weight_scale = final_scale + if self.bits == 8: + self.weight = scaled_tensor.to(torch.float8_e4m3fn) + else: + self.weight_packed = pack_fp4_to_uint8(scaled_tensor) + + if global_scale is not None: + self.weight_global_scale = global_scale.to(torch.float32).to(device) + if input_global_scale is not None: + self.input_global_scale = input_global_scale.to(torch.float32).to(device).reshape([1]) + + # add transform weight + self.register_buffer("hadamard_matrix", w_transform.weight.to(device)) + return + + QuantLinear.pack = _pack_patched + QuantLinear._pack_patched = True diff --git a/auto_round/algorithms/transforms/rotation/transforms.py b/auto_round/algorithms/transforms/rotation/transforms.py new file mode 100644 index 000000000..78dcd0d6b --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/transforms.py @@ -0,0 +1,172 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Concrete ``torch.nn.Module`` implementations of Hadamard transforms. + +:class:`HadamardTransform` – block-diagonal Hadamard (deterministic). +:class:`RandomHadamardTransform` – randomly signed Hadamard. +:func:`build_hadamard_transform` – factory that selects the right class. 
+""" + +from __future__ import annotations + +import inspect +import math +from typing import Any, Callable, Dict + +import torch +import torch.nn as nn + +from auto_round.algorithms.transforms.rotation.utils.math import ( + deterministic_hadamard_matrix, + random_hadamard_matrix, +) +from auto_round.algorithms.transforms.rotation.utils.matrix import apply_transform_weight + +__all__ = [ + "HadamardTransform", + "RandomHadamardTransform", + "HADAMARDS", + "build_hadamard_transform", +] + + +def _filter_kwargs(fn: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Return only the keyword arguments accepted by *fn*.""" + accepted = inspect.signature(fn).parameters.keys() + return {k: v for k, v in kwargs.items() if k in accepted} + + +class HadamardTransform(nn.Module): + """Block-diagonal deterministic Hadamard rotation. + + The rotation matrix ``W`` (stored as a frozen ``nn.Parameter``) is + constructed once from :func:`deterministic_hadamard_matrix` and + normalised by ``1 / sqrt(block_size)``. + + Args: + block_size: Size of each Hadamard block (must be a power of 2). + device: Device to place the weight on. + precision: Dtype for the weight tensor. + location: ``"weight"`` (default) or ``"input"`` – controls the + orientation of the multiplication in :meth:`forward`. + module_type: ``type(module)`` passed to + :func:`~utils.matrix.apply_transform_weight`. + inverse: If ``True``, use transposed orientation (for activation + transforms that are the inverse of the weight transform). + """ + + def __init__( + self, + block_size: int | None = 32, + device: torch.device | None = None, + precision: torch.dtype | None = None, + location: str = "weight", + module_type: type[nn.Module] = nn.Linear, + inverse: bool = False, + ) -> None: + super().__init__() + self.size = block_size if block_size is not None else 32 + self.scale = 1.0 / math.sqrt(self.size) + self.location = location + self.module_type = module_type + self.inverse = inverse + self.weight = self._build_weight(self.size, device, precision) + + def _build_weight( + self, + size: int, + device: torch.device | None, + precision: torch.dtype | None, + ) -> nn.Parameter: + data = deterministic_hadamard_matrix(size, precision, device) * self.scale + return nn.Parameter(data, requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + ori_shape = x.shape + x = x.view(-1, self.size) + out = apply_transform_weight( + self.weight.to(x.device), + x.to(dtype=self.weight.dtype), + self.location, + self.module_type, + ) + return out.to(x.dtype).view(ori_shape) + + +class RandomHadamardTransform(HadamardTransform): + """Randomly signed Hadamard rotation. + + Extends :class:`HadamardTransform` with a seeded random diagonal so the + same seed always produces the same rotation matrix. + + Args: + seed: Integer seed for the internal ``torch.Generator``. + generator: Pre-built ``torch.Generator`` (overrides *seed* if given). + *args, **kwargs: Forwarded to :class:`HadamardTransform`. 
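A quick sanity check of the deterministic transform above (illustrative only): the rotation is applied block-wise along the last dimension and, thanks to the 1/sqrt(block_size) scaling, preserves norms.

    import torch

    from auto_round.algorithms.transforms.rotation.transforms import HadamardTransform

    t = HadamardTransform(block_size=4, precision=torch.float32)
    w = torch.randn(8, 8)
    rotated = t(w)                    # each 4-wide slice of the last dim is rotated by the same block
    assert rotated.shape == w.shape
    assert torch.allclose(rotated.norm(), w.norm(), atol=1e-4)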
+ """ + + def __init__( + self, + *args: Any, + seed: int | None = None, + generator: torch.Generator | None = None, + **kwargs: Any, + ) -> None: + if generator is not None: + self.generator = generator + else: + self.generator = torch.Generator() + if seed is not None: + self.generator.manual_seed(seed) + super().__init__(*args, **kwargs) + + def _build_weight( + self, + size: int, + device: torch.device | None, + precision: torch.dtype | None, + ) -> nn.Parameter: + data = random_hadamard_matrix(size, precision, device, self.generator) * self.scale + if self.inverse: + data = data.T + return nn.Parameter(data, requires_grad=False) + + +# --------------------------------------------------------------------------- +# Registry and factory +# --------------------------------------------------------------------------- + +#: Maps ``hadamard_type`` strings to their transform classes. +HADAMARDS: dict[str, type[HadamardTransform]] = { + "hadamard": HadamardTransform, + "random_hadamard": RandomHadamardTransform, +} + + +def build_hadamard_transform(hadamard_type: str, **kwargs: Any) -> HadamardTransform: + """Instantiate the correct :class:`HadamardTransform` subclass. + + Args: + hadamard_type: Key into :data:`HADAMARDS` (``"hadamard"`` or + ``"random_hadamard"``). + **kwargs: Forwarded to the transform constructor after filtering + out unsupported keys. + + Returns: + A new :class:`HadamardTransform` instance. + """ + if hadamard_type not in HADAMARDS: + raise ValueError(f"Unknown hadamard_type: {hadamard_type!r}. " f"Available: {sorted(HADAMARDS)}") + cls = HADAMARDS[hadamard_type] + return cls(**_filter_kwargs(cls.__init__, kwargs)) diff --git a/auto_round/algorithms/transforms/rotation/utils/__init__.py b/auto_round/algorithms/transforms/rotation/utils/__init__.py new file mode 100644 index 000000000..0b3bf2eae --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/utils/__init__.py @@ -0,0 +1,2 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 diff --git a/auto_round/experimental/transform/utils/hadamards.safetensors b/auto_round/algorithms/transforms/rotation/utils/hadamards.safetensors similarity index 100% rename from auto_round/experimental/transform/utils/hadamards.safetensors rename to auto_round/algorithms/transforms/rotation/utils/hadamards.safetensors diff --git a/auto_round/algorithms/transforms/rotation/utils/math.py b/auto_round/algorithms/transforms/rotation/utils/math.py new file mode 100644 index 000000000..14b15ce0d --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/utils/math.py @@ -0,0 +1,144 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hadamard matrix construction utilities. + +Provides ``deterministic_hadamard_matrix`` (Sylvester construction) and +``random_hadamard_matrix`` (loaded from a precomputed safetensors file). 
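A wiring sketch combining the factory with the calibration patches from patch.py (deterministic variant; block_size chosen to match a group size of 32). For `random_hadamard`, the weight- and input-side transforms would additionally need to share a seed, and `patch_quantlinear` would be required so the matrix is serialised at packing time.

    from auto_round.algorithms.transforms.rotation.patch import (
        patch_wrapperlinear_to_apply_transform,
        patch_wrapperwalayer_forward_to_apply_transform,
    )
    from auto_round.algorithms.transforms.rotation.transforms import build_hadamard_transform

    w_t = build_hadamard_transform("hadamard", block_size=32, location="weight")
    in_t = build_hadamard_transform("hadamard", block_size=32, location="input")

    patch_wrapperlinear_to_apply_transform(w_t, in_t)      # weights + activations in WrapperLinear
    patch_wrapperwalayer_forward_to_apply_transform(in_t)  # activations in WrapperWALayer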
+""" + +# note that hadamard matrix multiplication reuses code from +# https://github.com/vllm-project/compressed-tensors/blob/main/src/compressed_tensors/transform/utils/hadamard.py + +from __future__ import annotations + +import math +from pathlib import Path + +import torch +from safetensors import safe_open + +__all__ = ["deterministic_hadamard_matrix", "random_hadamard_matrix", "is_pow2"] + +# Precomputed Hadamard matrices for non-power-of-2 sizes. +_HADAMARD_MATRICES_PATH: Path = Path(__file__).parent / "hadamards.safetensors" + + +def deterministic_hadamard_matrix( + size: int, + dtype: torch.dtype = torch.bfloat16, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Construct an ``(size × size)`` Hadamard matrix via Sylvester's construction. + + ``size`` must be a power of 2. + + Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py + + Args: + size: Order of the matrix; must be a power of 2. + dtype: Output dtype. + device: Output device. + + Returns: + Hadamard tensor of shape ``(size, size)``. + """ + if size <= 0: + raise ValueError("Cannot construct Hadamard matrix with size <= 0") + log2 = int(math.log2(size)) + if size != 2**log2: + raise ValueError("Deterministic Hadamard requires size == 2^n") + + H = torch.tensor([[1]], dtype=dtype, device=device) + for _ in range(log2): + H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H)))) + return H + + +def random_hadamard_matrix( + size: int, + dtype: torch.dtype = torch.bfloat16, + device: torch.device = torch.device("cpu"), + gen: torch.Generator | None = None, +) -> torch.Tensor: + """Create a randomly signed Hadamard matrix of order *size*. + + Supports non-powers-of-2 by reading a precomputed base matrix from + ``hadamards.safetensors`` and composing it with a random ±1 diagonal. + + Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py + + Args: + size: Dimension of the matrix. + dtype: Output dtype. + device: Output device. + gen: Optional seeded ``torch.Generator`` for reproducibility. + + Returns: + Randomly signed Hadamard tensor of shape ``(size, size)``. 
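For intuition, the Sylvester property the construction above relies on (the transform classes then divide by sqrt(size) to obtain an orthogonal matrix):

    import torch

    from auto_round.algorithms.transforms.rotation.utils.math import deterministic_hadamard_matrix

    H = deterministic_hadamard_matrix(8, dtype=torch.float32)
    assert torch.equal(H @ H.T, 8.0 * torch.eye(8))   # entries are +/-1 and rows are mutually orthogonal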
+ """ + Q = torch.randint(0, 2, (size,), generator=gen, dtype=dtype).to(device) + Q = Q * 2 - 1 + return _matmul_hadU(torch.diag(Q)) + + +def is_pow2(n: int) -> bool: + """Return ``True`` iff *n* is a positive power of two.""" + return n > 0 and (n & (n - 1)) == 0 + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _fetch_hadamard_divisor( + n: int, + dtype: torch.dtype, + device: torch.device = torch.device("cpu"), + file_path: Path = _HADAMARD_MATRICES_PATH, +) -> torch.Tensor | None: + """Return the largest precomputed Hadamard divisor ``k`` of *n* such that + ``n / k`` is a power of two, or ``None`` if no such entry exists.""" + open_device = torch.device("cpu") if device.type == "meta" else device + with safe_open(str(file_path), framework="pt", device=str(open_device)) as f: + divisors = sorted((int(key) for key in f.keys()), reverse=True) + for divisor in divisors: + if n % divisor == 0 and is_pow2(n // divisor): + return f.get_tensor(str(divisor)).to(dtype=dtype, device=device) + return None + + +def _matmul_hadU(X: torch.Tensor) -> torch.Tensor: + """Multiply *X* (a diagonal matrix) by the appropriate Hadamard matrix.""" + size = X.size(0) + dtype = X.dtype + device = X.device + + hadK = _fetch_hadamard_divisor(size, dtype, device=device) + if hadK is None: + raise ValueError(f"Cannot construct random Hadamard matrix of size {size}") + K = hadK.size(0) + + inp = X.clone().view(-1, size, 1) + out = inp.clone() + while inp.shape[1] > K: + inp = inp.view(inp.shape[0], inp.shape[1] // 2, 2, inp.shape[2]) + out = out.view(inp.shape) + out[:, :, 0, :] = inp[:, :, 0, :] + inp[:, :, 1, :] + out[:, :, 1, :] = inp[:, :, 0, :] - inp[:, :, 1, :] + out = out.view(inp.shape[0], inp.shape[1], -1) + inp, out = out, inp + assert inp.shape[1] == K + del out + return (hadK.view(1, K, K).to(inp) @ inp).view(X.shape) diff --git a/auto_round/algorithms/transforms/rotation/utils/matrix.py b/auto_round/algorithms/transforms/rotation/utils/matrix.py new file mode 100644 index 000000000..c8c723d83 --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/utils/matrix.py @@ -0,0 +1,102 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Linear-algebra helpers for applying weight/activation rotation matrices. + +Note: ``apply_transform_weight`` reuses ideas from +https://github.com/vllm-project/compressed-tensors/blob/main/src/compressed_tensors/transform/utils/matrix.py +""" + +from __future__ import annotations + +import torch + +__all__ = ["apply_transform_weight", "multihead_matmul"] + + +def apply_transform_weight( + transform_weight: torch.Tensor, + value: torch.Tensor, + location: str, + module_type: type[torch.nn.Module], +) -> torch.Tensor: + """Apply *transform_weight* to *value* according to *location*. + + The mathematical relationship for a ``torch.nn.Linear`` layer: + + .. 
code-block:: none + + y = x W.T (standard linear) + yh = (x V) (U.T W Vi.T).T (rotated linear) + + where *V* is the input-side rotation and *U* the output-side rotation. + + Args: + transform_weight: The rotation matrix to apply. + value: The tensor to rotate (weight or activation). + location: ``"input"`` or ``"weight"``. + module_type: ``type(module)`` – determines how the weight transform + is oriented. + + Returns: + Rotated tensor with the same shape as *value*. + """ + if location == "input": + return multihead_matmul(value, transform_weight) + + if module_type is torch.nn.Linear: + return multihead_matmul(value, transform_weight.T) + + raise NotImplementedError( + f"apply_transform_weight: unsupported location={location!r} " f"with module_type={module_type}" + ) + + +def multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + """Block-diagonal matrix multiplication over the last two dimensions. + + Handles the case where *A* and *B* have different sizes in their inner + dimension by treating the smaller matrix as a repeated block-diagonal. + + For example, if ``A.shape[-1] == 2 * B.shape[-2]``, this is equivalent to:: + + A @ block_diag(B, B) + + Args: + A: Left-hand tensor. + B: Right-hand tensor. + + Returns: + Result of the generalised matrix multiplication. + + Raises: + ValueError: If the inner dimensions are not evenly divisible. + """ + a_inner = A.shape[-1] + b_inner = B.shape[-2] + + if a_inner > b_inner: + if a_inner % b_inner != 0: + raise ValueError(f"multihead_matmul: A.shape[-1]={a_inner} is not divisible " f"by B.shape[-2]={b_inner}") + num_heads = a_inner // b_inner + A = A.unflatten(-1, (num_heads, b_inner)) + return (A @ B).flatten(-2, -1) + + if a_inner < b_inner: + if b_inner % a_inner != 0: + raise ValueError(f"multihead_matmul: B.shape[-2]={b_inner} is not divisible " f"by A.shape[-1]={a_inner}") + num_heads = b_inner // a_inner + B = B.unflatten(-2, (num_heads, a_inner)) + return (A @ B).flatten(-3, -2) + + return A @ B diff --git a/auto_round/algorithms/transforms/rotation/utils/triton/__init__.py b/auto_round/algorithms/transforms/rotation/utils/triton/__init__.py new file mode 100644 index 000000000..0b3bf2eae --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/utils/triton/__init__.py @@ -0,0 +1,2 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 diff --git a/auto_round/algorithms/transforms/rotation/utils/triton/mxfp4.py b/auto_round/algorithms/transforms/rotation/utils/triton/mxfp4.py new file mode 100644 index 000000000..c26413248 --- /dev/null +++ b/auto_round/algorithms/transforms/rotation/utils/triton/mxfp4.py @@ -0,0 +1,192 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
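A small equivalence check for `multihead_matmul` from matrix.py above (illustrative shapes): when A's inner dimension is a multiple of B's, the call matches an explicit block-diagonal product.

    import torch

    from auto_round.algorithms.transforms.rotation.utils.matrix import multihead_matmul

    A = torch.randn(4, 64)                 # e.g. two 32-wide heads
    B = torch.randn(32, 32)                # one 32x32 rotation block
    out = multihead_matmul(A, B)           # A.shape[-1] == 2 * B.shape[-2]
    ref = A @ torch.block_diag(B, B)
    assert torch.allclose(out, ref, atol=1e-4)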
+ +# Refer code here: +# https://github.com/IST-DASLab/FP-Quant/blob/master/inference_lib/src/fp_quant/module/triton/mxfp4.py + +import torch +import triton # pylint: disable=E0401 +import triton.language as tl # pylint: disable=E0401 + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 32 * 32}), + triton.Config({"BLOCK_SIZE": 64 * 32}), + triton.Config({"BLOCK_SIZE": 128 * 32}), + triton.Config({"BLOCK_SIZE": 256 * 32}), + triton.Config({"BLOCK_SIZE": 512 * 32}), + ], + key=[], +) +@triton.jit +def mxfp4_forward_kernel( + x_ptr, + hadamard_matrix_ptr, + output_ptr, + clip_mask_ptr, + n_elements: tl.constexpr, + hadamard_dim: tl.constexpr, + group_size: tl.constexpr, + gaussian_scale: tl.constexpr, + quest: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim) + hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim) + + # load x + pid = tl.program_id(0) + start_idx = pid * BLOCK_SIZE + offsets = start_idx + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x_flat = tl.load(x_ptr + offsets, mask=mask) + + # hadamard transform + x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim)) + x_had = tl.dot(x, hadamard_matrix) + + # group + x_had_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size)) + + # scale + # quest=True: per-group Gaussian-based scale = gaussian_scale * std + # quest=False: per-group max-abs-based scale, adjusted to FP4 range + if quest: + mean_squared = tl.sum(x_had_grouped * x_had_grouped, axis=-1, keep_dims=True) / group_size + mean = tl.sum(x_had_grouped, axis=-1, keep_dims=True) / group_size + std = tl.sqrt(mean_squared - mean * mean) + scales = gaussian_scale * std + 1e-8 + shared_exps = tl.exp2(tl.floor(tl.log2(scales))) + x_had_scaled = x_had_grouped / shared_exps + else: + scales = tl.max(tl.abs(x_had_grouped), axis=-1, keep_dims=True) + shared_exps = tl.exp2(tl.floor(tl.log2(scales)) - 2) / (3 / 4) + x_had_scaled = x_had_grouped / shared_exps + + # quantize + # Map abs(x) to FP4 levels {0, 0.5, 1, 1.5, 2, 3, 4, 6} + x_had_scaled_abs = tl.abs(x_had_scaled) + x_had_scaled_sign = tl.where( + x_had_scaled > 0, + 1, + -1, + ) + + x_fp4 = ( + tl.where( + x_had_scaled_abs > 5, + 6, + tl.where( + x_had_scaled_abs > 3.5, + 4, + tl.where( + x_had_scaled_abs > 2.5, + 3, + tl.where( + x_had_scaled_abs > 1.75, + 2, + tl.where( + x_had_scaled_abs > 1.25, + 1.5, + tl.where( + x_had_scaled_abs > 0.75, + 1, + tl.where( + x_had_scaled_abs > 0.25, + 0.5, + 0, + ), + ), + ), + ), + ), + ), + ) + * x_had_scaled_sign + ) + if clip_mask_ptr is not None: + tl.store( + clip_mask_ptr + offsets, + tl.reshape(x_had_scaled_abs < 6, (BLOCK_SIZE,)), + mask=mask, + ) + + # dequantize + x_dequantized = x_fp4 * shared_exps + + # Reshape back to flat form for storage + x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,)) + + # store + tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask) + + +@torch.compiler.disable() +def mxfp4_forward_kernel_wrapper( + x, + hadamard_matrix, + return_clip_mask=False, + quest=False, + gaussian_scale=3 / 4, +): + """ + Refer code here: + https://github.com/IST-DASLab/FP-Quant/blob/master/inference_lib/src/fp_quant/module/triton/mxfp4.py + Apply Hadamard transform + group-wise FP4 quantize/dequantize on x. + + Note: + The output is still in the Hadamard-transformed space (no inverse Hadamard is applied). 
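For readers without Triton, a plain-PyTorch reference of the non-quest scale and FP4 mapping implemented by the kernel above (Hadamard rotation and clip mask omitted; this is an illustrative sketch, not part of the patch):

    import torch

    _FP4_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
    _FP4_EDGES = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0])   # kernel's decision thresholds

    def mxfp4_qdq_reference(x: torch.Tensor, group_size: int = 32) -> torch.Tensor:
        g = x.reshape(-1, group_size).float()
        amax = g.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)       # guard log2(0)
        shared_exp = torch.exp2(torch.floor(torch.log2(amax)) - 2) / (3 / 4)
        scaled = g / shared_exp
        q = _FP4_LEVELS[torch.bucketize(scaled.abs(), _FP4_EDGES)] * scaled.sign()
        return (q * shared_exp).reshape(x.shape).to(x.dtype)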
+ """ + # Pick a device — we require CUDA + device = x.device + if device.type != "cuda": + raise RuntimeError( + f"mxfp4_forward_kernel_wrapper requires a CUDA tensor for 'x', " + f"but got device '{device.type}'. Please move inputs to CUDA before calling." + ) + + # Ensure hadamard_matrix is on the same CUDA device + if hadamard_matrix.device != device: + hadamard_matrix = hadamard_matrix.to(device) + + # Make sure inputs are contiguous + x = x.contiguous() + hadamard_matrix = hadamard_matrix.contiguous() + + # Create output tensors on CUDA + output = torch.empty_like(x, device=device) + if return_clip_mask: + clip_mask = torch.empty_like(x, dtype=torch.bool, device=device).contiguous() + else: + clip_mask = None + + # Get total number of elements and calculate grid for launching the kernel + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + + # Launch kernel – no need for `with torch.device(...)` + mxfp4_forward_kernel[grid]( + x_ptr=x, + hadamard_matrix_ptr=hadamard_matrix, + output_ptr=output, + clip_mask_ptr=clip_mask, + n_elements=n_elements, + hadamard_dim=hadamard_matrix.shape[-1], + group_size=32, + gaussian_scale=gaussian_scale, + quest=quest, + ) + + return output, clip_mask diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 3f46d40fc..b2bec2651 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -17,6 +17,7 @@ import torch +import auto_round.envs as envs from auto_round.compressors import ( AdamCompressor, BaseCompressor, @@ -33,6 +34,9 @@ if TYPE_CHECKING: from auto_round.auto_scheme.gen_auto_scheme import AutoScheme +# Default to new architecture; set AR_DISABLE_NEW_ARCH=true/1 to force old architecture. +NEW_ARCH = not envs.AR_DISABLE_NEW_ARCH + class AutoRound: """Automatic weight rounding (Signed Gradient Descent) for LLM quantization @@ -53,7 +57,7 @@ class AutoRound: enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers. """ - SKIP_ARGS = ("local_args", "kwargs", "cls", "model_cls", "dynamic_compressor", "extra_config", "enable_adam") + SKIP_ARGS = ("local_args", "kwargs", "cls", "model_cls", "dynamic_compressor", "extra_config") bits: int | None group_size: int | tuple | None @@ -160,6 +164,11 @@ def __new__( local_args = {k: v for k, v in locals().items() if k not in cls.SKIP_ARGS} + if NEW_ARCH: + from auto_round.compressors_new.entry import AutoRoundCompatible + + return AutoRoundCompatible(**local_args, **kwargs) + model_cls = [] has_multimodal_assets = kwargs.get("processor") is not None or kwargs.get("image_processor") is not None @@ -238,21 +247,23 @@ def _sampling_inputs( for key in input_others.keys(): if "positional_inputs" in key: continue - if (key not in share_cache_keys or len(indices) == 1) and not isinstance( - input_others[key], (str, bool, type(None)) - ): - current_input_others[key] = None - if input_others[key] is not None: - current_input_others[key] = [input_others[key][i] for i in indices] - if len(indices) == 1: - current_input_others[key] = current_input_others[key][0] - else: - try: - current_input_others[key] = torch.cat(current_input_others[key], dim=0) - except TypeError as err: - logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.") - else: + # Shared cache keys (e.g. position_embeddings, position_ids, cache_position) are stored + # directly as-is (not wrapped in a per-sample list) when batch_size > 1. Indexing such + # values by sample index would incorrectly decompose them (e.g. 
(cos, sin)[0] == cos). + # Always pass them through unchanged. + if key in share_cache_keys or isinstance(input_others[key], (str, bool, type(None))): current_input_others[key] = input_others[key] + elif input_others[key] is not None: + current_input_others[key] = [input_others[key][i] for i in indices] + if len(indices) == 1: + current_input_others[key] = current_input_others[key][0] + else: + try: + current_input_others[key] = torch.cat(current_input_others[key], dim=0) + except TypeError as err: + logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.") + else: + current_input_others[key] = None return current_input_ids, current_input_others diff --git a/auto_round/calibration/__init__.py b/auto_round/calibration/__init__.py new file mode 100644 index 000000000..14a492441 --- /dev/null +++ b/auto_round/calibration/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/auto_round/calibration/utils.py b/auto_round/calibration/utils.py new file mode 100644 index 000000000..523561aef --- /dev/null +++ b/auto_round/calibration/utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from auto_round.modeling.fused_moe.replace_modules import materialize_model_, safe_to_cpu_ +from auto_round.utils import ( + is_quantized_input_module, +) + + +def _infer_last_cache_name(block_names, layer_names=None, requested_last_cache_name=None): + """The latest required cache layer for early-stop forward. + + If there are multiple cache targets, return ``None`` and let runtime + hooks stop only after all targets are observed in real forward execution. 
+ """ + if layer_names is None: + layer_names = [] + + if requested_last_cache_name is not None: + return requested_last_cache_name + + cache_targets = list(block_names) + list(layer_names) + if len(cache_targets) == 1: + return cache_targets[0] + + # return None here to enable the logic in _should_stop_cache_forward + return None + + +def _update_inputs(inputs: dict, q_inputs: dict) -> tuple[dict, torch.Tensor]: + from auto_round.context.model import ModelContext + + model_context = ModelContext() + if model_context.is_diffusion: + # flux transformer model's blocks will update hidden_states and encoder_hidden_states + input_id_str = [key for key in inputs.keys() if "hidden_state" in key] + if q_inputs is not None: + q_inputs = {k: q_inputs.pop(k, None) for k in input_id_str} + return inputs, q_inputs + + keys = inputs.keys() + input_id_str = [key for key in keys if key.startswith("hidden_state")] + if len(input_id_str) != 1: + raise RuntimeError( + "hidden_states arg mismatch error," "please raise an issue in https://github.com/intel/auto-round/issues" + ) + inputs["input_ids"] = inputs.pop(input_id_str[0], None) + if q_inputs is not None: + q_inputs = q_inputs.pop(input_id_str[0], None) + return inputs, q_inputs diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 985ccfc46..0503c8235 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -33,6 +33,7 @@ from transformers import AutoConfig, set_seed from auto_round import envs +from auto_round.algorithms.quantization.sign_round.sign_sgd import SignSGD from auto_round.auto_scheme.gen_auto_scheme import AutoScheme from auto_round.compressors.shard_writer import shard_writer from auto_round.compressors.utils import ( @@ -71,7 +72,6 @@ preset_name_to_scheme, scheme_to_preset_name, ) -from auto_round.sign_sgd import SignSGD from auto_round.special_model_handler import get_predefined_fixed_attr, get_predefined_ignore_layers, update_module from auto_round.utils import ( INNER_SUPPORTED_LAYER_TYPES, @@ -1865,6 +1865,14 @@ def _adjust_immediate_packing_and_saving(self): if self.low_cpu_mem_usage and self.is_immediate_packing: self.is_immediate_saving = True + if self.low_cpu_mem_usage and not self.is_immediate_packing: + logger.info( + "`low_cpu_mem_usage` is only supported when `immediate_packing` is True. " + "Setting `low_cpu_mem_usage` to False." 
+ ) + self.low_cpu_mem_usage = False + self.is_immediate_saving = False + if self.low_cpu_mem_usage and self.is_immediate_packing: if formats[0].is_gguf(): logger.warning( @@ -2679,10 +2687,10 @@ def post_process_cache_data(batch_size, data, data_name): Processed data or None """ new_data = data - if batch_size <= 1: - return new_data if data_name in self.shared_cache_keys: return None + if batch_size <= 1: + return new_data if "alibi" in data_name: if isinstance(data, torch.Tensor): alibi = data @@ -2736,7 +2744,7 @@ def forward(m, hidden_states=None, *positional_inputs, **kwargs): continue if key not in self.inputs[name].keys(): # initialization data = to_device(kwargs[key], device=torch.device("cpu")) - if data is None or (self.batch_size > 1 and key in self.shared_cache_keys): + if data is None or key in self.shared_cache_keys: self.inputs[name][key] = data continue if self.batch_size <= 1: @@ -3743,21 +3751,23 @@ def _sampling_inputs( for key in input_others.keys(): if "positional_inputs" in key: continue - if (key not in share_cache_keys or len(indices) == 1) and not isinstance( - input_others[key], (str, bool, type(None)) - ): - current_input_others[key] = None - if input_others[key] is not None: - current_input_others[key] = [input_others[key][i] for i in indices] - if len(indices) == 1: - current_input_others[key] = current_input_others[key][0] - else: - try: - current_input_others[key] = torch.cat(current_input_others[key], dim=0) - except TypeError as err: - logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.") - else: + # Shared cache keys (e.g. position_embeddings, position_ids, cache_position) are stored + # directly as-is (not wrapped in a per-sample list) when batch_size > 1. Indexing such + # values by sample index would incorrectly decompose them (e.g. (cos, sin)[0] == cos). + # Always pass them through unchanged. 
+ if key in share_cache_keys or isinstance(input_others[key], (str, bool, type(None))): current_input_others[key] = input_others[key] + elif input_others[key] is not None: + current_input_others[key] = [input_others[key][i] for i in indices] + if len(indices) == 1: + current_input_others[key] = current_input_others[key][0] + else: + try: + current_input_others[key] = torch.cat(current_input_others[key], dim=0) + except TypeError as err: + logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.") + else: + current_input_others[key] = None return current_input_ids, current_input_others diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index 7b6f04545..fb6c8a0e6 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -28,7 +28,7 @@ from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, GGUF_CONFIG, GGUF_INNER_CONFIG, QK_K, ModelType from auto_round.logger import logger from auto_round.schemes import QuantizationScheme, get_gguf_scheme, preset_name_to_scheme -from auto_round.utils import check_to_quantized, to_standard_regex +from auto_round.utils import check_to_quantized, infer_bits_by_data_type, to_standard_regex class BackendDataType(str, Enum): @@ -225,30 +225,6 @@ def collect_best_params(block, cache_device="cpu"): return params -def infer_bits_by_data_type(data_type: str): - """Infer bits by data_type - - Args: - data_type (str): data_type - - Returns: - int: bits inferred by data_type, None means cannot infer correct bits by data_type - """ - from auto_round.utils import SUPPORTED_DTYPES - - if data_type is None: - return 16 - for supported_dtype in SUPPORTED_DTYPES: - if data_type.startswith(supported_dtype) and len(data_type) > len(supported_dtype): - ##first check the following two bits - suc_2str = data_type[len(supported_dtype) : len(supported_dtype) + 2] - if str.isdigit(suc_2str): - return int(suc_2str) - if str.isdigit(data_type[len(supported_dtype)]): - return int(data_type[len(supported_dtype)]) - return None - - def _get_safetensor_layer_names_not_in_model(model, all_module_names: list) -> list: """Collect layer names from safetensor files that are not loaded into the model. diff --git a/auto_round/compressors_new/__init__.py b/auto_round/compressors_new/__init__.py new file mode 100644 index 000000000..66dbe5de9 --- /dev/null +++ b/auto_round/compressors_new/__init__.py @@ -0,0 +1,53 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
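To make the shared-cache comment in `_sampling_inputs` above concrete (tensor shapes illustrative): entries such as rotary position embeddings are stored once for the whole batch as a tuple, so per-sample indexing would silently discard half of them.

    import torch

    cos, sin = torch.randn(1, 2048, 128), torch.randn(1, 2048, 128)
    position_embeddings = (cos, sin)       # stored as-is when batch_size > 1

    assert position_embeddings[0] is cos   # the old indexing path kept only `cos` and dropped `sin`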
+ +# Lazy imports to avoid circular dependencies +# Users should import from specific modules instead of this __init__.py + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from auto_round.compressors_new.calib import CalibCompressor, CalibratedRTNCompressor + from auto_round.compressors_new.entry import AutoRoundCompatible, AutoRound + from auto_round.compressors_new.zero_shot import ZeroShotCompressor + +__all__ = [ + "AutoRound", + "CalibCompressor", + "CalibratedRTNCompressor", + "ZeroShotCompressor", + "AutoRoundCompatible", +] + + +def __getattr__(name): + """Lazy import to avoid circular dependencies.""" + if name == "AutoRound" or name == "AutoRoundCompatible": + from auto_round.compressors_new.entry import AutoRound, AutoRoundCompatible + + if name == "AutoRound": + return AutoRound + return AutoRoundCompatible + elif name in ("CalibCompressor", "CalibratedRTNCompressor"): + from auto_round.compressors_new.calib import CalibCompressor, CalibratedRTNCompressor + + return { + "CalibCompressor": CalibCompressor, + "CalibratedRTNCompressor": CalibratedRTNCompressor, + }[name] + elif name == "ZeroShotCompressor": + from auto_round.compressors_new.zero_shot import ZeroShotCompressor + + return ZeroShotCompressor + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/auto_round/compressors_new/architecture_visualization.py b/auto_round/compressors_new/architecture_visualization.py new file mode 100644 index 000000000..7ff020733 --- /dev/null +++ b/auto_round/compressors_new/architecture_visualization.py @@ -0,0 +1,332 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 +""" +New Architecture Visualization - Mixin Pattern Combination Table + +Demonstrates all possible combinations of model types and compression algorithms. 
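The module-level `__getattr__` above follows PEP 562: nothing from entry.py, calib.py or zero_shot.py is imported until the attribute is first touched, for example:

    import auto_round.compressors_new as cn

    AutoRound = cn.AutoRound        # triggers __getattr__("AutoRound") -> imports entry.py lazily
    Calib = cn.CalibCompressor      # triggers __getattr__("CalibCompressor") -> imports calib.py lazily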
+""" + + +def print_architecture_table(): + """Print architecture combination table""" + + print("\n" + "=" * 110) + print("AutoRound New Architecture - Mixin Pattern Combination Table") + print("=" * 110 + "\n") + + print(f"{'Model Type':<15} {'Config Type':<20} {'AutoRound (dynamic class)':<40} {'Base classes':<35}") + print("-" * 110) + + # LLM combinations + print(f"{'LLM':<15} {'SignRoundConfig':<20} {'CalibCompressor':<40} {'CalibCompressor':<35}") + print(f"{'LLM':<15} {'RTNConfig':<20} {'CalibratedRTNCompressor':<40} {'CalibratedRTNCompressor':<35}") + print(f"{'LLM':<15} {'RTNConfig':<20} {'ZeroShotCompressor':<40} {'ZeroShotCompressor':<35}") + + print() + + # MLLM combinations (dynamic classes created in entry.py) + print(f"{'MLLM':<15} {'SignRoundConfig':<20} {'MLLMCalibCompressor':<40} {'MLLMMixin + CalibCompressor':<35}") + print( + f"{'MLLM':<15} {'RTNConfig':<20} {'MLLMCalibratedRTNCompressor':<40} " + f"{'MLLMMixin + CalibratedRTNCompressor':<35}" + ) + print(f"{'MLLM':<15} {'RTNConfig':<20} {'MLLMZeroShotCompressor':<40} {'MLLMMixin + ZeroShotCompressor':<35}") + + print() + + # Diffusion combinations (dynamic classes created in entry.py) + print( + f"{'Diffusion':<15} {'SignRoundConfig':<20} {'DiffusionCalibCompressor':<40} " + f"{'DiffusionMixin + CalibCompressor':<35}" + ) + print( + f"{'Diffusion':<15} {'RTNConfig':<20} {'DiffusionCalibratedRTNCompressor':<40} " + f"{'DiffusionMixin + CalibratedRTNCompressor':<35}" + ) + print( + f"{'Diffusion':<15} {'RTNConfig':<20} {'DiffusionZeroShotCompressor':<40} " + f"{'DiffusionMixin + ZeroShotCompressor':<35}" + ) + + print("\n" + "=" * 110 + "\n") + + +def print_mixin_explanation(): + """Print Mixin pattern explanation""" + + print("=" * 110) + print("Mixin Pattern Explanation") + print("=" * 110 + "\n") + + print("✨ Core Components:") + print("-" * 110) + print(" 1. MLLMMixin - MLLM features (processor, template, quant_nontext_module, etc.)") + print(" 2. DiffusionMixin - Diffusion features (pipeline loading, guidance_scale, etc.)") + print(" 3. CalibCompressor - AutoRoundCompatible: gradient-based calibration quantization") + print(" 4. CalibratedRTNCompressor - RTN with importance-matrix (imatrix) or act calibration") + print(" 5. 
ZeroShotCompressor - Zero-shot RTN (no calibration data needed)") + + print("\n🎯 Combination Approach:") + print("-" * 110) + print(" Dynamic classes created on-the-fly inside AutoRound.__new__():") + print(" class MLLMCalibCompressor(MLLMMixin, CalibCompressor): pass") + print(" class MLLMCalibratedRTNCompressor(MLLMMixin, CalibratedRTNCompressor): pass") + print(" class MLLMZeroShotCompressor(MLLMMixin, ZeroShotCompressor): pass") + + print("\n💡 Advantages:") + print("-" * 110) + print(" ✓ Flexible Combination: Any model type can be combined with any compression algorithm") + print(" ✓ Code Reuse: Mixin code is written once and reused across all compression algorithms") + print(" ✓ Clear Separation: Model-specific logic (Mixin) and compression algorithm are independent") + print(" ✓ Easy Extension: Add new model types without touching existing compressor code") + + print("\n" + "=" * 110 + "\n") + + +def print_post_init_flow(): + """Print the post_init execution flow""" + + print("=" * 110) + print("BaseCompressor.post_init() Execution Flow") + print("=" * 110 + "\n") + + print(""" +BaseCompressor.post_init() +│ +├─ Step 1: Resolve formats (str → list[OutputFormat]) +│ └─ get_formats(self.formats, self) +│ +├─ Step 2: Apply format-specific model patches +│ └─ model_context.apply_patches(formats) +│ ├─ _patch_custom_moe_modules() # e.g. Qwen3VL MoE top_k fix +│ ├─ update_module(model, formats) # add gguf_pack_linear etc. +│ └─ assign global_name to all modules +│ +├─ Step 3: Setup quantizer on the patched model +│ └─ quantizer = BaseQuantizers.from_config(config) +│ └─ quantizer.post_init() +│ ├─ get ModelContext / CompressContext singletons +│ ├─ _parse_scheme() → resolve final quant attrs +│ ├─ get_block_names(quant_vision=quant_nontext_module) +│ ├─ find_matching_blocks() → quant_block_list +│ └─ back-fill to_quant_block_names if it was None +│ +└─ Step 4: Setup device map, torch compile, offloader + """) + + print("=" * 110 + "\n") + + +def print_usage_examples(): + """Print usage examples""" + + print("=" * 110) + print("Usage Examples") + print("=" * 110 + "\n") + + print("Example 1: MLLM + AutoRoundCompatible (gradient-based)") + print("-" * 110) + print(""" +from auto_round.compressors_new.entry import AutoRound +from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig + +config = SignRoundConfig(scheme="W4A16", iters=200, nsamples=128) +compressor = AutoRound( + config=config, + model="/models/Qwen2-VL-2B-Instruct", + processor=processor, + template="qwen2_vl", + quant_nontext_module=False, # set True to also quantize vision encoder +) +# Dynamically creates: class MLLMCalibCompressor(MLLMMixin, CalibCompressor) + """) + + print("\nExample 2: MLLM + RTN with imatrix") + print("-" * 110) + print(""" +from auto_round.algorithms.quantization.rtn.config import RTNConfig + +config = RTNConfig(scheme="W4A16") +compressor = AutoRound( + config=config, + model="/models/Qwen2-VL-2B-Instruct", + format="gguf_k", # gguf_k triggers CalibratedRTNCompressor + processor=processor, +) +# Dynamically creates: class MLLMCalibratedRTNCompressor(MLLMMixin, CalibratedRTNCompressor) + """) + + print("\nExample 3: Diffusion + AutoRoundCompatible") + print("-" * 110) + print(""" +config = SignRoundConfig(scheme="W4A16", iters=200) +compressor = AutoRound( + config=config, + model="/models/stable-diffusion-2-1", + guidance_scale=7.5, +) +# Dynamically creates: class DiffusionCalibCompressor(DiffusionMixin, CalibCompressor) + """) + + print("\n" + "=" * 110 + "\n") + + +def 
print_mro_example(): + """Print MRO (Method Resolution Order) example""" + + print("=" * 110) + print("Method Resolution Order (MRO) Example") + print("=" * 110 + "\n") + + print("For class MLLMCalibCompressor(MLLMMixin, CalibCompressor):") + print("-" * 110) + print(""" +MLLMCalibCompressor (dynamic, created in AutoRound.__new__) + └─> MLLMMixin + └─> CalibCompressor + └─> BaseCompressor + └─> object + +Execution order when calling __init__(): + 1. MLLMCalibCompressor.__init__() → not defined, falls through + 2. MLLMMixin.__init__() + - Save MLLM-specific attrs: processor, template, quant_nontext_module, … + - kwargs.setdefault("quant_nontext_module", quant_nontext_module) + - Call super().__init__() → enters CalibCompressor + 3. CalibCompressor.__init__() → BaseCompressor.__init__() + - pops quant_nontext_module from kwargs + - Creates ModelContext(…, quant_nontext_module=quant_nontext_module) + - ModelContext.__init__ eagerly loads the model + - Creates CompressContext singleton + +MLLMCalibCompressor instance has: + ✓ MLLM features from MLLMMixin (processor, template, calib() override) + ✓ Calibration compression from CalibCompressor + ✓ Model/context management from BaseCompressor + """) + + print("=" * 110 + "\n") + + +def print_decision_tree(): + """Print decision tree""" + + print("=" * 110) + print("AutoRound Creation Decision Tree") + print("=" * 110 + "\n") + + print(""" +AutoRound.__new__(config, model, format, **kwargs) +│ +├─ Step 1: Detect model type +│ model_type = detect_model_type(model) +│ ├─ is_diffusion_model() → "diffusion" +│ ├─ is_mllm_model() → "mllm" +│ └─ else → "llm" +│ +├─ isinstance(config, SignRoundConfig) +│ ├─ model_type == "mllm" +│ │ └─> class MLLMCalibCompressor(MLLMMixin, CalibCompressor) +│ ├─ model_type == "diffusion" +│ │ └─> class DiffusionCalibCompressor(DiffusionMixin, CalibCompressor) +│ └─ model_type == "llm" +│ └─> CalibCompressor +│ +└─ isinstance(config, RTNConfig) + │ + ├─ enable_imatrix OR needs_act_calib → CalibratedRTNCompressor path + │ ├─ gguf_k format → enable_imatrix = True + │ ├─ symmetric int RTN → enable_imatrix = True + │ ├─ static activation quantization → needs_act_calib = True + │ │ + │ ├─ model_type == "mllm" + │ │ └─> class MLLMCalibratedRTNCompressor(MLLMMixin, CalibratedRTNCompressor) + │ ├─ model_type == "diffusion" + │ │ └─> class DiffusionCalibratedRTNCompressor(DiffusionMixin, CalibratedRTNCompressor) + │ └─ model_type == "llm" + │ └─> CalibratedRTNCompressor + │ + └─ else (zero-shot) → ZeroShotCompressor path + ├─ model_type == "mllm" + │ └─> class MLLMZeroShotCompressor(MLLMMixin, ZeroShotCompressor) + ├─ model_type == "diffusion" + │ └─> class DiffusionZeroShotCompressor(DiffusionMixin, ZeroShotCompressor) + └─ model_type == "llm" + └─> ZeroShotCompressor + """) + + print("=" * 110 + "\n") + + +def print_quantizer_interface(): + """Print the BaseQuantizers interface contract""" + + print("=" * 110) + print("BaseQuantizers Interface - Name-based quantize_block / quantize_layer") + print("=" * 110 + "\n") + + print(""" +All quantizers use module *names* (str) instead of module objects. +The module is retrieved internally via get_module(model, name). 
+ + BaseQuantizers (abstract) + ├─ quantize_block(block_name: Union[str, list[str]], input_ids, input_others, **kwargs) + │ str → get_module(model, block_name) + │ list[str] → WrapperMultiblock([get_module(model, n) for n in block_name]) + │ (used when nblocks > 1 in CalibCompressor) + │ + └─ quantize_layer(layer_name: str, **kwargs) + → get_module(model, layer_name) + + Implementations: + ├─ RTNQuantizer.quantize_block(block_name: str) + ├─ OptimizedRTNQuantizer.quantize_block(block_name: str, input_ids, input_others) + └─ SignRoundQuantizer.quantize_block(block_name: Union[str, list[str]], input_ids, input_others) + """) + + print("=" * 110 + "\n") + + +def main(): + """Run all visualizations""" + + print_architecture_table() + print_mixin_explanation() + print_post_init_flow() + print_usage_examples() + print_mro_example() + print_decision_tree() + print_quantizer_interface() + + print("=" * 110) + print("🎉 New architecture supports 9 combinations (3 model types × 3 compression algorithms)") + print(" CalibratedRTNCompressor (was ImatrixCompressor) lives in calib.py") + print("=" * 110) + + +if __name__ == "__main__": + main() + + print(f"{'LLM':<15} {'RTNConfig':<20} {'RTN (zero-shot)':<20} {'ZeroShotCompressor':<35}") + + print() + + # MLLM combinations + print(f"{'MLLM':<15} {'SignRoundConfig':<20} {'AutoRoundCompatible':<20} {'MLLMCalibCompressor':<35}") + print(f"{'':<15} {'':<20} {'':<20} {' = MLLMMixin + CalibCompressor':<35}") + print(f"{'MLLM':<15} {'RTNConfig':<20} {'RTN + imatrix':<20} {'MLLMImatrixCompressor':<35}") + print(f"{'':<15} {'':<20} {'':<20} {' = MLLMMixin + ImatrixCompressor':<35}") + print(f"{'MLLM':<15} {'RTNConfig':<20} {'RTN (zero-shot)':<20} {'MLLMZeroShotCompressor':<35}") + print(f"{'':<15} {'':<20} {'':<20} {' = MLLMMixin + ZeroShotCompressor':<35}") + + print() + + # Diffusion combinations + print(f"{'Diffusion':<15} {'SignRoundConfig':<20} {'AutoRoundCompatible':<20} {'DiffusionCalibCompressor':<35}") + print(f"{'':<15} {'':<20} {'':<20} {' = DiffusionMixin + CalibCompressor':<35}") + print(f"{'Diffusion':<15} {'RTNConfig':<20} {'RTN + imatrix':<20} {'DiffusionImatrixCompressor':<35}") + print(f"{'':<15} {'':<20} {'':<20} {' = DiffusionMixin + ImatrixCompressor':<35}") + print(f"{'Diffusion':<15} {'RTNConfig':<20} {'RTN (zero-shot)':<20} {'DiffusionZeroShotCompressor':<35}") + print(f"{'':<15} {'':<20} {'':<20} {' = DiffusionMixin + ZeroShotCompressor':<35}") + + print("\n" + "=" * 100 + "\n") diff --git a/auto_round/compressors_new/base.py b/auto_round/compressors_new/base.py new file mode 100644 index 000000000..37f5726e8 --- /dev/null +++ b/auto_round/compressors_new/base.py @@ -0,0 +1,1310 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
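A generic-Python sketch of the on-the-fly composition the visualization script describes; `CalibCompressor` is imported from calib.py as in the lazy `__init__` above, while the `MLLMMixin` import path is an assumption (it is not shown in this excerpt):

    from auto_round.compressors_new.calib import CalibCompressor
    from auto_round.compressors_new.mllm import MLLMMixin   # assumed location of the mixin

    # Equivalent to: class MLLMCalibCompressor(MLLMMixin, CalibCompressor): pass
    MLLMCalibCompressor = type("MLLMCalibCompressor", (MLLMMixin, CalibCompressor), {})
    print([c.__name__ for c in MLLMCalibCompressor.__mro__])
    # expected, per the MRO section above:
    # ['MLLMCalibCompressor', 'MLLMMixin', 'CalibCompressor', 'BaseCompressor', 'object']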
+import copy +import gc +import os +import sys +from dataclasses import asdict, dataclass, fields +from typing import Any, Optional, Union + +import torch +from transformers import set_seed + +from auto_round.algorithms.alg_config import AlgConfig +from auto_round.algorithms.quantization import BaseQuantizers, QuantizationConfig +from auto_round.algorithms.transforms import ( + BaseRotationConfig, + apply_rotation, +) +from auto_round.compressors.utils import is_mx_fp, is_nv_fp +from auto_round.compressors_new.shard_writer import ShardWriter +from auto_round.compressors_new.utils import _get_save_folder_name, block_forward, set_layer_config +from auto_round.context.compress import CompressContext +from auto_round.context.model import ModelContext +from auto_round.formats import OutputFormat, get_formats +from auto_round.logger import logger +from auto_round.schemes import ( + QuantizationScheme, + _handle_special_schemes, + _parse_scheme, + get_gguf_scheme, + preset_name_to_scheme, +) +from auto_round.special_model_handler import get_predefined_fixed_attr, get_predefined_ignore_layers +from auto_round.utils import ( + INNER_SUPPORTED_LAYER_TYPES, + SUPPORTED_LAYER_TYPES, + TORCH_VERSION_AT_LEAST_2_6, + compile_func, + compress_layer_names, + convert_dtype_str2torch, + extract_block_names_to_str, + find_matching_blocks, + get_block_names, + is_debug_mode, + is_hpex_available, + is_quantized_input_module, + memory_monitor, +) +from auto_round.utils.device import ( + _force_trim_malloc, + get_major_device, + patch_xpu_sdpa_drop_causal_mask, + set_non_auto_device_map, +) +from auto_round.utils.offload import OffloadManager +from auto_round.wrapper import wrapper_block + + +@dataclass +class SerializedCompressorConfig: + bits: Optional[int] = None + act_bits: Optional[int] = None + data_type: Optional[str] = None + act_data_type: Optional[str] = None + group_size: Optional[int] = None + act_group_size: Optional[int] = None + sym: Optional[bool] = None + act_sym: Optional[bool] = None + act_dynamic: Optional[bool] = None + amp: Optional[bool] = None + batch_size: Optional[int] = None + enable_minmax_tuning: Optional[bool] = True + enable_norm_bias_tuning: Optional[bool] = False + enable_quanted_input: Optional[bool] = True + gradient_accumulate_steps: Optional[int] = None + iters: Optional[int] = None + lr: Optional[float] = None + low_gpu_mem_usage: Optional[bool] = None + minmax_lr: Optional[float] = None + nsamples: Optional[int] = None + quant_block_list: Optional[list[str]] = None + regex_config: Optional[dict[str, Any]] = None + scale_dtype: Optional[str] = None + seqlen: Optional[int] = None + supported_types: Optional[list[str]] = SUPPORTED_LAYER_TYPES + static_attention_dtype: Optional[str] = None + static_kv_dtype: Optional[str] = None + super_bits: Optional[int] = None + super_group_size: Optional[int] = None + to_quant_block_names: Optional[list[str]] = None + transform_configs: Optional[list[dict[str, Any]]] = None + + +class BaseCompressor(object): + need_calib: bool = True + compress_context: CompressContext = None + model_context: ModelContext = None + shard_writer: ShardWriter = None + supported_types = SUPPORTED_LAYER_TYPES + inner_supported_types = INNER_SUPPORTED_LAYER_TYPES + + # ── Scheme state (populated during resolve_scheme / _scheme_post_init) ── + is_auto_scheme: bool = False + orig_scheme = None + scheme = None + scale_dtype = None + layer_config = None + has_qlayer_outside_block: bool = False + regex_config: dict = None + quant_block_list: list = None + 
to_quant_block_names = None + ignore_layers: str = "" + quant_lm_head: bool = False + _scheme_resolved: bool = False + scheme_generator = None + + def __init__( + self, + config: Union[AlgConfig, list[AlgConfig]], + model: Union[torch.nn.Module, str], + tokenizer=None, + platform="hf", + format=None, + scheme="W4A16", + low_gpu_mem_usage: bool = False, + device_map: Union[str, torch.device, int, dict] = 0, + enable_torch_compile: bool = False, + seed: int = 42, + low_cpu_mem_usage: bool = True, + layer_config=None, + nsamples: int = None, + seqlen: int = None, + **kwargs, + ): + self.quantize_config = None + self.transform_configs: list[BaseRotationConfig] = [] + _config_list = config if isinstance(config, list) else [config] + for _cfg in _config_list: + if isinstance(_cfg, QuantizationConfig): + self.quantize_config = _cfg + elif isinstance(_cfg, BaseRotationConfig): + self.transform_configs.append(_cfg) + assert self.quantize_config is not None, "QuantizationConfig is required for Compressor" + + # Compressor-level calibration/layer params (do not live in QuantizationConfig). + self.layer_config = layer_config + self.nsamples = nsamples if nsamples is not None else 128 + self.seqlen = seqlen if seqlen is not None else 2048 + + # Scheme is passed directly to the compressor, not stored in QuantizationConfig. + self.scheme = scheme + + # TODO: refactor calibration + self.calibration = None + + self.formats = format + + # Extra/legacy kwargs for backward compatibility + # Major version releases may pack them with extra configuration options + amp = kwargs.pop("amp", True) + nblocks = kwargs.pop("nblocks", 1) + disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", True) + enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False) + + self._offloader = OffloadManager(enabled=low_cpu_mem_usage, mode="offload", offload_dir_prefix="compressor") + + # Model related + model_dtype = kwargs.pop("model_dtype", None) + trust_remote_code = kwargs.pop("trust_remote_code") if "trust_remote_code" in kwargs else True + quant_nontext_module = kwargs.pop("quant_nontext_module", False) + + self.static_attention_dtype = kwargs.pop("static_attention_dtype", None) + # Attention static dtype + if self.static_attention_dtype is not None: + logger.warning("The static attention dtype is experimental and currently has limited support.") + # KV cache, this one does not affect tuning but will collect some infos during tuning + self.static_kv_dtype = kwargs.pop("static_kv_dtype", None) + if self.static_kv_dtype is not None: + logger.warning("The static kv is experimental and currently has limited support.") + + if kwargs: + logger.warning( + f"unrecognized keys {list(kwargs.keys())} were passed. " + "Please check them. If you use old api, just ignore this warning." + ) + if "CUBLAS_WORKSPACE_CONFIG" not in os.environ: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + # Deprecated, default not to use torch.use_deterministic_algorithms + if not disable_deterministic_algorithms or enable_deterministic_algorithms: + if not disable_deterministic_algorithms: + logger.warning( + "default not use deterministic_algorithms. disable_deterministic_algorithms is deprecated," + " please use enable_deterministic_algorithms instead. 
" + ) + + torch.use_deterministic_algorithms(True, warn_only=False) + else: + torch.use_deterministic_algorithms(True, warn_only=True) + + # XPU SDPA workaround: drop pure causal masks so FLASH backend is used, + # and set torch.use_deterministic_algorithms(False) + # instead of MATH (avoids ~10x peak-VRAM blow-up during block tuning). + patch_xpu_sdpa_drop_causal_mask() + + device = kwargs.pop("device", None) + if device is not None: + logger.warning("`device` is deprecated, please use `device_map` instead") + + # Tuning hyperparameters + self.seed = seed + set_seed(self.seed) + + self.nblocks = nblocks + + self.enable_torch_compile = enable_torch_compile + + # Whether to pack the layer immediately after tuning + # Managed via self.compress_context.is_immediate_packing / is_immediate_saving + + torch.set_printoptions(precision=3, sci_mode=True) + + if is_hpex_available(): + logger.info("habana_frameworks is available, import htcore explicitly.") + import habana_frameworks.torch.core as htcore # pylint: disable=E0401 + + # Reset both context singletons before creating fresh instances so that + # consecutive AutoRound creations don't inherit stale config from earlier ones. + CompressContext.reset_context() + ModelContext.reset_context() + + # Resolve the device eagerly so ModelContext can be created before + # CompressContext. Creating ModelContext first places the large model + # allocation early in the heap, matching the OLD arch allocation order + # and reducing C-heap fragmentation (which is amplified on HPU). + _device = get_major_device(device_map if device_map is not None else 0) + + self.model_context = ModelContext( + model, + tokenizer=tokenizer, + platform=platform, + model_dtype=model_dtype, + trust_remote_code=trust_remote_code, + amp=amp, + need_calib=self.need_calib, + device=_device, + formats=self.formats, + is_act_quantize=self.quantize_config.is_act_quantize, + quant_nontext_module=quant_nontext_module, + ) + # Alternatively, you can use CompressContext.create_context + self.compress_context = CompressContext( + low_cpu_mem_usage, + low_gpu_mem_usage, + device_map, + enable_torch_compile, + formats=self.formats, + static_kv_dtype=self.static_kv_dtype, + static_attention_dtype=self.static_attention_dtype, + ) + self.shard_writer = None + + # scale_dtype is resolved in quantizer.resolve_scheme() after scheme resolution, + # so it is not initialized here to avoid premature evaluation with an unresolved scheme. + + # Flag for post_init idempotency. Set to False here so post_init() can be called + # either via quantize_and_save() (preferred, outside inference_mode) or directly + # from quantize() as a fallback for non-AutoScheme cases. + self._post_init_done = False + + # Apply torch compile adjustments eagerly so that ar.enable_torch_compile + # reflects the correct value immediately after construction (not only after post_init). + self._adjust_torch_compile(enable_torch_compile) + self.compress_context.enable_torch_compile = self.enable_torch_compile + + self.blocks_requiring_input_ids = [] + self.has_variable_block_shape = False + fixed_attr = get_predefined_fixed_attr(self.model_context.model) or {} + for key, value in fixed_attr.items(): + setattr(self, key, value) + + # ── Scheme resolution ───────────────────────────────────────────────────── + + def resolve_scheme(self, model_context=None, compress_context=None, dataset: str = None) -> None: + """Phase-1 init: resolve scheme and bind config attrs (no model structure needed). 
+ + Must be called BEFORE ``get_formats()`` and BEFORE ``_scheme_post_init()``. + Idempotent: safe to call multiple times. + """ + if self._scheme_resolved: + return + + if model_context is not None: + self.model_context = model_context + if compress_context is not None: + self.compress_context = compress_context + if dataset is not None: + self.dataset = dataset + + scheme_fields = {f.name for f in fields(QuantizationScheme)} + user_scheme_overrides = { + k: getattr(self.quantize_config, k) + for k in scheme_fields + if getattr(self.quantize_config, k, None) is not None + } + default_scheme, self.is_auto_scheme, final_attrs = _parse_scheme(self.scheme, user_scheme_overrides) + + for key, value in final_attrs.items(): + setattr(self.quantize_config, key, value) + if hasattr(self, key): + setattr(self, key, value) + self.quantize_config.check_config() + + self.orig_scheme = copy.deepcopy(self.scheme) + self.scheme = default_scheme + + gguf_scheme_name = get_gguf_scheme(self.scheme) + if self.scale_dtype is None: + self.scale_dtype = "fp32" if gguf_scheme_name else "fp16" + self.scale_dtype = convert_dtype_str2torch(self.scale_dtype) + + self._scheme_resolved = True + + def _scheme_post_init(self) -> None: + """Phase-4 init: build layer config on the patched model. + + Requires ``resolve_scheme()`` to have been called first. + Must be called AFTER ``model_context.apply_patches()``. + """ + assert self._scheme_resolved, ( + "resolve_scheme() must be called before _scheme_post_init(). " + "BaseCompressor.post_init() does this automatically." + ) + + enable_gguf_official_mixed = not self.is_auto_scheme + + if self.quant_block_list is None: + quant_nontext_module = getattr(self.model_context, "quant_nontext_module", False) + all_blocks = get_block_names(self.model_context.model, quant_vision=quant_nontext_module) + self.quant_block_list = find_matching_blocks( + self.model_context.model, all_blocks, self.to_quant_block_names + ) + if self.to_quant_block_names is None and self.quant_block_list: + self.to_quant_block_names = extract_block_names_to_str(self.quant_block_list) + self.quantize_config.to_quant_block_names = self.to_quant_block_names + + self.configure_layer_config(enable_gguf_official_mixed=enable_gguf_official_mixed) + + def _gen_auto_scheme(self) -> dict[str, dict]: + """Generate per-layer config via AutoScheme delta-loss selection.""" + if self.model_context.is_mllm: + # AutoScheme on a VLM only scores the language tower (the block + # walker in delta_loss already skips vision/audio sub-trees) and + # uses a pure-text calibration dataset by default, falling back to + # the multimodal dataloader if the VLM rejects text-only forward. + logger.info( + "AutoScheme on multimodal LLM: scoring the language tower only " + "with text-only calibration (multimodal dataloader will be used " + "as a fallback if needed)." 
+ ) + + if is_quantized_input_module(self.model_context.model): + raise NotImplementedError("AutoScheme does not currently support quantized input models (e.g., FP8).") + + all_dtypes = [] + all_gguf = True + for option in self.orig_scheme.options: + dtype = "int" + if isinstance(option, str): + if not option.lower().startswith("gguf"): + all_gguf = False + option = preset_name_to_scheme(option) + else: + all_gguf = False + + if isinstance(option, QuantizationScheme): + dtype = option.data_type + elif isinstance(option, dict): + dtype = option.get("data_type", "int") + + all_dtypes.append(dtype) + + unique_dtypes = set(all_dtypes) + if len(unique_dtypes) > 1 and not all_gguf: + logger.warning( + "Models with mixed data_types " + "cannot yet be exported to real formats except GGUF. " + "Please save the model using the `fake` format for now." + ) + + layer_config, self.has_qlayer_outside_block, self.regex_config = set_layer_config( + self.model_context.model, + self.layer_config, + self.scheme, + self.scale_dtype, + self.supported_types, + self.inner_supported_types, + self.quant_block_list, + self.ignore_layers, + self.quant_lm_head, + enable_gguf_official_mixed=False, + is_mllm=self.model_context.is_mllm, + ) + quant_layer_names = layer_config.keys() + + # ---- VLM: peel non-text sub-trees AutoScheme should not score ---- # + nontext_skipped_layers: dict[str, dict] = {} + if self.model_context.is_mllm: + from auto_round.utils import get_block_names + + quant_nontext = getattr(self, "quant_nontext_module", False) + scoreable_blocks = get_block_names(self.model_context.model, quant_vision=quant_nontext) + scoreable_block_prefixes = tuple(blk for group in scoreable_blocks for blk in group) + if quant_nontext: + peel_markers = ("audio", "speech", "wav", "waveform") + tower_label = "language+vision" + peel_label = "audio/speech" + else: + peel_markers = ( + "vision", + "visual", + "image", + "img", + "audio", + "speech", + "wav", + "waveform", + ) + tower_label = "language" + peel_label = "vision/audio" + + def _is_scoreable_layer(name: str) -> bool: + if any(name == p or name.startswith(p + ".") for p in scoreable_block_prefixes): + return True + lname = name.lower() + return not any(marker in lname for marker in peel_markers) + + scoreable_layer_config = {} + for name, cfg in layer_config.items(): + if _is_scoreable_layer(name): + scoreable_layer_config[name] = cfg + else: + nontext_skipped_layers[name] = cfg + + if nontext_skipped_layers: + logger.info( + "AutoScheme (VLM): scoring %d %s-tower layers; " + "%d %s-tower layers kept at their original 16-bit configuration.", + len(scoreable_layer_config), + tower_label, + len(nontext_skipped_layers), + peel_label, + ) + layer_config = scoreable_layer_config + quant_layer_names = layer_config.keys() + + scheme_keys = {f.name for f in fields(QuantizationScheme)} + fixed_layer_scheme_new = { + k: {key: v[key] for key in scheme_keys & v.keys()} + for k, v in layer_config.items() + if v.get("fixed_by_user", False) + } + + from auto_round.auto_scheme.gen_auto_scheme import GenScheme + + if ( + not self.compress_context.enable_torch_compile + and self.quantize_config.super_bits is None + and not self.orig_scheme.low_gpu_mem_usage + ): + logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM") + self.scheme_generator = GenScheme( + self.orig_scheme, + self.model_context.model, + quant_layer_names, + fixed_layer_scheme_new, + self.dataset, + device_map=self.compress_context.device_map, + 
tokenizer=self.model_context.tokenizer, + enable_torch_compile=self.compress_context.enable_torch_compile, + processor=self.model_context.processor, + ) + layer_config = self.scheme_generator.get_layer_config() + # Re-attach vision/audio-tower layers we peeled off earlier so the + # downstream quantization pipeline sees the complete layer map. + if nontext_skipped_layers: + allowed_keys = {f.name for f in fields(QuantizationScheme)} | { + "fixed_by_user", + "scale_dtype", + "scheme", + } + for name, cfg in nontext_skipped_layers.items(): + clean_cfg = {k: v for k, v in cfg.items() if k in allowed_keys} if isinstance(cfg, dict) else cfg + layer_config.setdefault(name, clean_cfg) + return layer_config + + def configure_layer_config(self, enable_gguf_official_mixed: bool | None = True) -> None: + """Build ``self.layer_config`` from the resolved scheme on the patched model.""" + is_gguf_format = (f := getattr(self.compress_context, "formats", None)) is not None and "gguf" in f + predefined_ignore_layers = get_predefined_ignore_layers(self.model_context.model) + compressed_predefined_ignore_layers = compress_layer_names(predefined_ignore_layers) + if not is_gguf_format: + predefined_ignore_layers = get_predefined_ignore_layers(self.model_context.model) + if predefined_ignore_layers: + logger.info(f"Using predefined ignore_layers: {compressed_predefined_ignore_layers}") + tmp_str = ",".join(predefined_ignore_layers) + if self.ignore_layers == "": + self.ignore_layers = tmp_str + else: + self.ignore_layers += "," + tmp_str + + if self.is_auto_scheme: + self.layer_config = self._gen_auto_scheme() + else: + self.layer_config = _handle_special_schemes( + self.orig_scheme, + self.layer_config, + self.model_context.model, + supported_types=SUPPORTED_LAYER_TYPES, + inner_supported_types=INNER_SUPPORTED_LAYER_TYPES, + quant_lm_head=self.quant_lm_head, + mllm=self.model_context.is_mllm, + ) + _gguf_orig_fmt = getattr(self, "_gguf_original_format_name", None) + if _gguf_orig_fmt and "_MIXED" in _gguf_orig_fmt.upper(): + self.layer_config = _handle_special_schemes( + _gguf_orig_fmt.lower(), + self.layer_config, + self.model_context.model, + supported_types=SUPPORTED_LAYER_TYPES, + inner_supported_types=INNER_SUPPORTED_LAYER_TYPES, + quant_lm_head=self.quant_lm_head, + mllm=self.model_context.is_mllm, + ) + + fill_default_value = not self.is_auto_scheme + self.layer_config, self.has_qlayer_outside_block, self.regex_config = set_layer_config( + self.model_context.model, + self.layer_config, + self.scheme, + self.scale_dtype, + SUPPORTED_LAYER_TYPES, + INNER_SUPPORTED_LAYER_TYPES, + self.quant_block_list, + self.ignore_layers, + self.quant_lm_head, + enable_gguf_official_mixed=enable_gguf_official_mixed, + is_mllm=self.model_context.is_mllm, + fill_default_value=fill_default_value, + gguf_format_name=getattr(self, "_gguf_format_name", None), + ) + + # ───────────────────────────────────────────────────────────────────────── + + @property + def mllm(self): + return self.model_context.is_mllm + + @property + def diffusion(self): + return self.model_context.is_diffusion + + def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: + """Sets the torch compile configuration for the tuning.""" + self.enable_torch_compile = enable_torch_compile + + # Determine fp8 / nvfp4 intent from raw config before scheme resolution. 
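+        # Illustrative traces of the intent detection below (hypothetical
+        # inputs): a scheme string containing "NVFP" or a data_type containing
+        # "nv_fp" marks nvfp4 intent; "fp8" in either dtype string, or an "fp"
+        # dtype paired with 8 bits, marks fp8 intent. nvfp4 always forces
+        # enable_torch_compile back to False; fp8 does so unless HPU (hpex) is
+        # available.
+        #   scheme="NVFP4", data_type=None        -> is_raw_nv_fp=True
+        #   data_type="fp8", act_data_type=None   -> is_raw_fp8=True
+        #   data_type="mx_fp", bits=8             -> is_raw_fp8=True ("fp" + 8 bits)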
+ cfg = self.quantize_config + raw_scheme = self.scheme if isinstance(self.scheme, str) else "" + raw_dt = (cfg.data_type or "").lower() + raw_adt = (cfg.act_data_type or "").lower() + raw_scheme_upper = raw_scheme.upper() + + is_raw_nv_fp = "nv_fp" in raw_dt or "nv_fp" in raw_adt or "NVFP" in raw_scheme_upper + is_raw_fp8 = ( + "fp8" in raw_dt + or "fp8" in raw_adt + or "FP8" in raw_scheme_upper + or ("fp" in raw_dt and getattr(cfg, "bits", 16) == 8) + or ("fp" in raw_adt and getattr(cfg, "act_bits", 16) == 8) + ) + + act_bits = getattr(cfg, "act_bits", 16) or 16 + if ( + not self.enable_torch_compile + and TORCH_VERSION_AT_LEAST_2_6 + and act_bits > 8 + and not is_debug_mode() + and not is_raw_fp8 + and self.need_calib + ): + logger.info( + "%s", + "'enable_torch_compile' is set to `False` by default. " + "Enabling it can reduce tuning cost by 20%, but it might throw an exception.", + ) + # On HPU, we rely on torch.compile to speed up the model execution. + if self.enable_torch_compile and is_raw_fp8 and not is_hpex_available(): + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as fp8 is enabled") + # TODO: fix https://github.com/intel/auto-round/issues/1109 + if self.enable_torch_compile and is_raw_nv_fp: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as nvfp4 is enabled") + super_group_size = getattr(cfg, "super_group_size", None) + enable_alg_ext = getattr(cfg, "enable_alg_ext", False) + if self.enable_torch_compile and super_group_size is not None and enable_alg_ext: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as super_group_size is set for algorithm extension") + + def _get_calibration_dataset(self) -> str: + """Resolve calibration dataset: self.dataset > AutoScheme.dataset > default.""" + dataset = self.__dict__.get("dataset", None) + if dataset: + return dataset + from auto_round.auto_scheme.gen_auto_scheme import AutoScheme + + scheme = self.scheme + if isinstance(scheme, AutoScheme) and scheme.dataset: + return scheme.dataset + return "NeelNanda/pile-10k" + + def post_init(self) -> None: + """One-time initialization that requires a loaded model. + + Call this OUTSIDE any ``@torch.inference_mode()`` context when using + AutoScheme – delta-loss selection needs autograd (backward pass). + ``quantize_and_save()`` does this automatically before entering the + inference-mode quantize loop. + + Delegates to five ordered pipeline phases; see each ``_resolve_scheme``, + ``_resolve_formats``, ``_patch_model``, ``_build_layer_config``, and + ``_hardware_setup`` for the precise preconditions and postconditions. + """ + if self._post_init_done: + return + + self._resolve_scheme() + + # After scheme resolution, is_act_quantize is known. When activation + # quantization is enabled and the model is in float16, convert to + # bfloat16 to match the old arch. This also detaches any parameter + # tensors that are still backed by safetensors' mmap, preventing + # per-block RSS growth (~14 MB/block) when .to(device) page-faults + # the underlying file pages into physical memory. 
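+        # Example (hypothetical): a float16 checkpoint quantized with an
+        # activation-quantizing scheme such as W8A8 is promoted to bfloat16
+        # here before tuning, matching the old architecture's behaviour and
+        # detaching any mmap-backed parameters in the process.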
+ if self.quantize_config.is_act_quantize and self.model_context.amp_dtype == torch.float16: + logger.warning("force to use bf16 for quantization tuning when enabling activation quantization") + self.model_context.amp_dtype = torch.bfloat16 + if self.model_context.model.dtype != torch.bfloat16: + self.model_context.model = self.model_context.model.to(torch.bfloat16) + + self._resolve_formats() + self._patch_model() + self._build_layer_config() + + # Reclaim temporaries from Phases 1-4 (scheme resolution, format + # parsing, model patching, layer-config walk) before Phase 5 + # allocates hardware/compile objects. This compacts the heap so that + # the fragmentation gap between live and freed blocks is minimised. + gc.collect() + _force_trim_malloc() + + self._hardware_setup() + + # Final trim after all init phases. + gc.collect() + _force_trim_malloc() + + self._post_init_done = True + + # ── Pipeline phase methods ──────────────────────────────────────────────── + + def _resolve_scheme(self) -> None: + """Phase 1 – Scheme resolution and quantizer construction. + + Preconditions: + - ``self.quantize_config`` is a valid :class:`QuantizationConfig`. + + Work performed: + - Seeds scheme-related attrs (``scale_dtype``, ``ignore_layers``, + ``quant_lm_head``, ``to_quant_block_names``) from ``quantize_config``. + - Calls :meth:`resolve_scheme` to derive ``data_type``, ``bits``, + ``sym``, ``scale_dtype`` etc. and write them back to both ``self`` + and ``self.quantize_config``. + - Constructs ``self.quantizer`` from the now-resolved config and wires + it to the current model / context. + - Binds ``self.wrapper_block`` for later use in quantizers. + + Postconditions: + - ``self.scheme`` and ``self.quantize_config`` carry resolved scheme attrs. + - ``self.quantizer`` is ready; calibration params (``seqlen``, + ``nsamples``) are synced. + """ + cfg = self.quantize_config + self.scale_dtype = cfg.scale_dtype + # self.layer_config is already set from __init__ (direct compressor param). + self.ignore_layers = cfg.ignore_layers + self.quant_lm_head = cfg.quant_lm_head + self.to_quant_block_names = cfg.to_quant_block_names + + # Resolve the scheme (pure config work: sets data_type / bits / sym / + # scale_dtype etc. on both self and self.quantize_config). + self.resolve_scheme( + model_context=self.model_context, + compress_context=self.compress_context, + dataset=self._get_calibration_dataset(), + ) + + # Create the quantizer now that the config holds resolved values. + self.quantizer = BaseQuantizers.from_config(self.quantize_config) + self.quantizer.model_context = self.model_context + self.quantizer.compress_context = self.compress_context + self.quantizer.model = self.model_context.model + self.quantizer.scale_dtype = self.scale_dtype + # Sync compressor-owned calibration params to quantizer. + self.quantizer.seqlen = self.seqlen + self.quantizer.nsamples = self.nsamples + self.wrapper_block = wrapper_block + + def _resolve_formats(self) -> None: + """Phase 2 – Format resolution, GGUF attr sync, and rotation application. + + Preconditions: + - Phase 1 complete: ``self.quantizer`` is initialised and the scheme + is resolved (``data_type``, ``bits``, ``sym`` etc. are final). + + Work performed: + - Converts a string ``self.formats`` to a list of + :class:`~auto_round.formats.OutputFormat` objects via + :func:`~auto_round.formats.get_formats`. + - Initialises :class:`~auto_round.compressors_new.shard_writer.ShardWriter` + when formats are present. 
+ - **(2b)** Detects GGUF-driven attribute mutations (``bits``, ``sym``, + ``data_type``, ``group_size``, etc.) that ``gguf_args_check`` may + have written onto ``self`` inside ``get_formats``, syncs them to + ``self.quantizer``, and rebuilds ``self.scheme`` accordingly. + - Merges any GGUF-injected entries into ``self.layer_config``. + - **(2d)** Applies rotation transforms from ``self.transform_configs``. + + Postconditions: + - ``self.formats`` is a list (or ``None``). + - ``self.compress_context.formats`` mirrors ``self.formats``. + - ``self.quantizer`` carries the GGUF-adjusted scheme attrs. + - ``self.scheme`` is consistent with the final quantization attrs. + - All rotation transforms have been applied to ``self.model_context.model``. + """ + # get_formats() inspects data_type / bits etc. that were just resolved. + if isinstance(self.formats, str): + self.formats = get_formats(self.formats, self) + if self.formats is not None: + self.compress_context.formats = self.formats + ShardWriter.reset() + # Defer ShardWriter construction to _ensure_shard_writer() to avoid + # heap fragmentation during post_init (parameter iteration). + + # Snapshot the user-specified layer_config before GGUF processing may + # add extra entries, so we can distinguish them later in step 2b. + _pre_gguf_layer_config = copy.copy(self.layer_config) or {} + + # ── 2b: propagate GGUF-adjusted attrs back to quantizer ────────────── + # gguf_args_check (called inside get_formats) may have overridden + # bits / sym / data_type / super_bits / super_group_size / group_size + # on *this* BaseCompressor object. The quantizer stored its own copies + # from Phase 1 (resolve_scheme), so we must sync them now, before + # _scheme_post_init() builds the layer_config in Phase 4. + _gguf_forwarded_attrs = ( + "bits", + "sym", + "data_type", + "super_bits", + "super_group_size", + "group_size", + "act_bits", + "scale_dtype", + ) + _any_gguf_attr_changed = False + for _attr in _gguf_forwarded_attrs: + if _attr in self.__dict__ and hasattr(self.quantizer, _attr): + if _attr not in ("scale_dtype", "act_bits") and getattr(self.quantizer, _attr) != self.__dict__[_attr]: + _any_gguf_attr_changed = True + setattr(self.quantizer, _attr, self.__dict__[_attr]) + # If gguf_args_check changed scheme attrs, rebuild the scheme on both + # the compressor (SchemeMixin) and the quantizer so that + # configure_layer_config() and set_layer_config() use the correct + # default_dict and gguf_name. + if _any_gguf_attr_changed: + from auto_round.schemes import PRESET_SCHEMES + from auto_round.schemes import QuantizationScheme as _QS + + # Prefer to derive the scheme directly from the gguf format name to + # avoid ambiguity (e.g. Q4_K_S and Q4_K_M share identical weight attrs). + _gguf_preset_scheme = None + _gguf_fmt_name = None + _gguf_original_fmt_name = None + for _fmt in self.formats or []: + # GGUFFormat (outer) has output_format="gguf" but backend.output_format="gguf:q4_k_m" + # GGUFFormat (inner/standalone) has output_format="gguf:q4_k_m" + _of = getattr(_fmt, "output_format", "") + if "gguf" in str(_of): + if str(_of) == "gguf": + # outer GGUFFormat: full format in _original_format (e.g. "gguf:q2_k_mixed") + # or backend.output_format (e.g. 
"gguf:q2_k_s" after _mixed → _s conversion) + _orig = getattr(_fmt, "_original_format", None) + if _orig: + _gguf_original_fmt_name = str(_orig).upper() + _backend = getattr(_fmt, "backend", None) + _of = getattr(_backend, "output_format", _of) if _backend is not None else _of + _preset_key = str(_of).upper() + if _preset_key in PRESET_SCHEMES: + _gguf_preset_scheme = PRESET_SCHEMES[_preset_key] + _gguf_fmt_name = _preset_key + break + if _gguf_preset_scheme is not None: + # Update scheme on both compressor and quantizer. + self.scheme = _gguf_preset_scheme + # Store the exact gguf format name so configure_layer_config / + # set_layer_config can use it directly, avoiding Q4_K_S / Q4_K_M ambiguity. + self._gguf_format_name = _gguf_fmt_name + # Store original format name (may include _mixed) for _handle_special_schemes + if _gguf_original_fmt_name: + self._gguf_original_format_name = _gguf_original_fmt_name + else: + _new_scheme_dict = {f.name: getattr(self, f.name, None) for f in fields(_QS)} + _new_scheme = _QS.from_dict({k: v for k, v in _new_scheme_dict.items() if v is not None}) + self.scheme = _new_scheme + + _gguf_layer_cfg = { + k: v for k, v in (self.__dict__.get("layer_config") or {}).items() if k not in (_pre_gguf_layer_config) + } + if _gguf_layer_cfg: + if self.layer_config is None: + self.layer_config = {} + for _lname, _lval in _gguf_layer_cfg.items(): + self.layer_config.setdefault(_lname, _lval) + + # ── 2d: apply rotation transforms ──────────────────────────────────── + if self.transform_configs: + logger.info("Applying Hadamard transform to the model.") + for rotation_cfg in self.transform_configs: + self.model_context.model = apply_rotation( + self.model_context.model, + rotation_cfg, + data_type=self.quantize_config.data_type, + ) + + def _patch_model(self) -> None: + """Phase 3 – Model structure patching. + + Preconditions: + - Phase 2 complete: ``self.formats`` is resolved so that + ``apply_patches`` can inspect format-specific requirements. + + Work performed: + - Delegates to :meth:`~auto_round.context.model.ModelContext.apply_patches` + which may replace or merge layers (e.g. MoE expert merging, adding + static-kv wrappers) to produce the final model topology. + + Postconditions: + - ``self.model_context.model`` reflects the definitive topology that + :meth:`_build_layer_config` will walk. + """ + # apply_patches() may replace layers (e.g. MoE expert merging); must + # happen before configure_layer_config() so it sees the final topology. + self.model_context.apply_patches(self.formats) + + def _build_layer_config(self) -> None: + """Phase 4 – Layer-config construction and quantizer sync. + + Preconditions: + - Phase 3 complete: model topology is final. + - ``self.scheme`` and all scheme-resolved attrs are consistent with + the (possibly GGUF-adjusted) values set in Phase 2. + + Work performed: + - Calls :meth:`_scheme_post_init` which walks the patched model to + build ``self.layer_config``, ``self.quant_block_list``, etc. + On the AutoScheme path this also runs delta-loss forward/backward + passes to select per-layer schemes. + - Syncs the fully-resolved ``layer_config`` and related attrs to + ``self.quantizer`` so quantization methods have the complete view. + + Postconditions: + - ``self.layer_config`` is fully populated. + - ``self.quantizer`` mirrors ``layer_config``, ``has_qlayer_outside_block``, + ``regex_config``, ``quant_block_list``, ``to_quant_block_names``, + ``scale_dtype``, and ``ignore_layers``. 
+ """ + # configure_layer_config() walks the patched model; _gen_auto_scheme() + # (AutoScheme path) runs delta-loss forward+backward passes. + self._scheme_post_init() + + # Sync the fully-resolved scheme state to the quantizer so that + # quantization methods (quantize_block, quantize_layer, etc.) have + # access to layer_config, scale_dtype, quant_block_list, etc. + self.quantizer.layer_config = self.layer_config + self.quantizer.has_qlayer_outside_block = self.has_qlayer_outside_block + self.quantizer.regex_config = self.regex_config + self.quantizer.quant_block_list = self.quant_block_list + self.quantizer.to_quant_block_names = self.to_quant_block_names + self.quantizer.scale_dtype = self.scale_dtype + self.quantizer.ignore_layers = self.ignore_layers + + def _hardware_setup(self) -> None: + """Phase 5 – Hardware and compile configuration. + + Preconditions: + - Phase 4 complete: ``layer_config`` is built and + ``has_qlayer_outside_block`` is known. + - ``self.quantize_config.data_type`` is the final resolved value + (needed by :meth:`_adjust_torch_compile`). + + Work performed: + - Applies the device map via :func:`~auto_round.utils.device.set_non_auto_device_map`. + - Re-evaluates ``torch.compile`` eligibility now that ``data_type`` is + resolved and writes the result back to ``compress_context``. + - Selects ``self.block_forward`` (compiled or plain). + - Resets the offload manager when ``low_cpu_mem_usage`` is active. + - Disables ``self.inplace`` when quantized layers live outside + transformer blocks (incompatible with in-place rewriting). + - Calls :meth:`_adjust_immediate_packing_and_saving` to decide whether + layers should be packed / written immediately after each block. + + Postconditions: + - ``self.block_forward`` is ready for use. + - ``compress_context.enable_torch_compile`` is final. + - ``self.inplace`` and ``compress_context.is_immediate_packing`` / + ``compress_context.is_immediate_saving`` are set to their definitive values. + """ + set_non_auto_device_map(self.model_context.model, self.compress_context.device_map) + # Re-evaluate torch.compile eligibility now that data_type is resolved. + self._adjust_torch_compile(self.enable_torch_compile) + self.compress_context.enable_torch_compile = self.enable_torch_compile + # Apply the same act-quantization / alg-ext guard as + # _resolve_block_forward() so we never compile when hooks are present. + cfg = self.quantize_config + _needs_plain_forward = (cfg.is_act_quantize and (not cfg.act_dynamic or cfg.is_act_nv_fp)) or getattr( + cfg, "enable_alg_ext", False + ) + # Only compile block_forward when it will actually be used (calibration path). + # For zero-shot compressors (need_calib=False), block_forward is never called, + # so skipping compilation avoids unnecessary HPU workspace allocation. + if self.enable_torch_compile and not _needs_plain_forward and self.need_calib: + self.block_forward = compile_func(block_forward, self.compress_context.device) + else: + self.block_forward = block_forward + if self.compress_context.low_cpu_mem_usage: + self._offloader.reset() + + # Disable inplace when quantized layers live outside transformer blocks. 
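+        # Example (hypothetical): quantizing lm_head together with the decoder
+        # blocks sets has_qlayer_outside_block=True, so in-place rewriting is
+        # disabled and the export path works on a copy instead of rewriting the
+        # tuned model in place.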
+ if self.has_qlayer_outside_block and self.need_calib: + self.inplace = False + + if not hasattr(self, "formats"): + logger.warning("this API is deprecated, please use `quantize_and_save` instead") + else: + self._adjust_immediate_packing_and_saving() + + # backward compatible with the legacy API + def __getattr__(self, name: str) -> Any: + if name in self.__dict__: + return self.__dict__[name] + + for obj in ["quantizer", "quantize_config", "model_context", "compress_context"]: + if obj not in self.__dict__: + continue + obj = object.__getattribute__(self, obj) + try: + return object.__getattribute__(obj, name) + except AttributeError: + continue + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + @property + def optimizer(self): + """Return the actual optimizer class, converting string to class for backward compat. + + Old API stored ``self.optimizer = torch.optim.AdamW`` (the class itself). + New arch stores the optimizer name as a string in ``quantize_config.optimizer``. + This property converts it so that ``ar.optimizer == torch.optim.AdamW`` works. + """ + if self.quantize_config is None: + return None + opt = getattr(self.quantize_config, "optimizer", None) + if opt is None: + # Default to AdamW when enable_adam=True and no explicit optimizer was set + if getattr(self.quantize_config, "enable_adam", False): + return torch.optim.AdamW + return None + if isinstance(opt, str): + return getattr(torch.optim, opt, None) + return opt + + def _adjust_immediate_packing_and_saving(self): + from auto_round.algorithms.quantization.rtn.config import RTNConfig + + if self.formats is None: + return + + formats = getattr(self, "formats", []) + if len(formats) == 1 and not formats[0].is_fake() and self.inplace: + self.compress_context.is_immediate_packing = True + + if self.has_qlayer_outside_block and self.need_calib: + self.compress_context.is_immediate_packing = False + + if not ("causallm" in self.model_context.model.__class__.__name__.lower() and not self.model_context.is_mllm): + # TODO For tied keys, there may some issues, we haven't not verified this + tied_weight_keys = getattr(self.model_context.model, "_tied_weight_keys", {}) + if len(tied_weight_keys) > 1: + self.compress_context.is_immediate_saving = False + if self.compress_context.low_cpu_mem_usage: + logger.warning("reset low_cpu_mem_usage to False due to tied weights") + return + if len(tied_weight_keys) == 1: + key = list(tied_weight_keys.keys())[0] + if "lm_head" not in key: + self.compress_context.is_immediate_saving = False + if self.compress_context.low_cpu_mem_usage: + logger.warning("reset low_cpu_mem_usage to False due to tied weights") + return + + if self.compress_context.low_cpu_mem_usage and self.compress_context.is_immediate_packing: + self.compress_context.is_immediate_saving = True + + if self.compress_context.low_cpu_mem_usage and self.compress_context.is_immediate_packing: + if formats[0].is_gguf(): + logger.warning( + "`low_cpu_mem_usage` is not fully supported for gguf format. " + "Setting `low_cpu_mem_usage` to False." + ) + self.compress_context.low_cpu_mem_usage = False + self.compress_context.is_immediate_saving = False + elif ( + self.has_qlayer_outside_block + and getattr(self, "disable_opt_rtn", None) + and isinstance(self.quantize_config, RTNConfig) + ): + logger.info( + "Keeping `low_cpu_mem_usage` enabled in RTN mode (iters=0): " + "RTN path uses blockwise quantization and supports per-block offloading." 
+ ) + elif self.has_qlayer_outside_block and not isinstance(self.quantize_config, RTNConfig): + logger.warning( + "`low_cpu_mem_usage` is not fully supported " + "when there are quantized layers outside blocks and optimized RTN is disabled. " + "Setting low_cpu_mem_usage to False." + ) + self.compress_context.low_cpu_mem_usage = False + self.compress_context.is_immediate_saving = False + + if self.compress_context.is_immediate_saving and not ( + "int" in self.quantize_config.data_type + or is_nv_fp(self.quantize_config.data_type) + or is_mx_fp(self.quantize_config.data_type) + ): + logger.warning("immediate_saving is only supported for int/nv_fp/mx_fp quantization, set to False") + self.compress_context.is_immediate_saving = False + + if self.output_dir is None: + self.compress_context.is_immediate_saving = False + + # Create ShardWriter eagerly only when immediate saving is active + # (it interleaves with the quantize loop). Otherwise keep it deferred + # until save_quantized() to avoid heap fragmentation during init. + if self.compress_context.is_immediate_saving: + self._ensure_shard_writer() + + def _ensure_shard_writer(self): + """Lazily create ShardWriter if it hasn't been created yet.""" + if self.shard_writer is None and self.formats is not None: + self.shard_writer = ShardWriter(self.model_context.model, bits=8) + + def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: + """Quantize the model and return the quantized model along with layer configurations.The entry of AutoRound. + Returns: + The quantized model and layer configurations. + """ + raise NotImplementedError("quantize method must be implemented in subclass") + + def save_quantized( + self, + output_dir: str = None, + format: Union[str, list[OutputFormat]] = None, + inplace: bool = True, + return_folders=False, + **kwargs, + ) -> torch.nn.Module: + """Save the quantized model to the specified output directory in the specified format. + + Args: + output_dir (str, optional): The directory to save the quantized model. Defaults to None. + format (str, optional): The format in which to save the model. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place. Defaults to True. + **kwargs: Additional keyword arguments specific to the export format. + + Returns: + object: The compressed model object. + """ + self.output_dir = output_dir + if output_dir is not None: + self.compress_context.output_dir = output_dir + if format is not None: + if isinstance(format, str) and getattr(self, "formats", None) is None: + logger.warning( + f"save_quantized with format is deprecated and will be deleted in auto_round version 1.0." + f" Please use Compressor(format='{format}' instead)." + ) + self.formats = get_formats(format, self) + self.compress_context.formats = self.formats + + if not self.model_context.quantized: + logger.warning("please run autoround.quantize first") + return + folders = [] + if self.formats is None: + logger.info("format is not set, using default auto_round format.") + self.formats = "auto_round" + if isinstance(self.formats, str): + self.formats = get_formats(self.formats, self) + self.compress_context.formats = self.formats + for format in self.formats: + save_folder = _get_save_folder_name(format) + if self.act_bits <= 8 and format.is_fake(): + logger.warning( + "Support for exporting activation quantization is limited. " + "Please ensure that your configuration is supported." 
+ ) + + serialization_dict = asdict(SerializedCompressorConfig()) + for key in serialization_dict: + serialization_dict[key] = getattr(self, key, serialization_dict[key]) + from auto_round.version import __version__ + + serialization_dict["autoround_version"] = __version__ + if serialization_dict.get("to_quant_block_names") is None and self.quantizer.quant_block_list: + serialization_dict["to_quant_block_names"] = extract_block_names_to_str(self.quantizer.quant_block_list) + if "scale_dtype" in serialization_dict.keys(): + serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"]) + + compressed_model = format.save_quantized( + save_folder, + model=self.model_context.model, + layer_config=self.quantizer.layer_config, + inplace=inplace, + tokenizer=self.model_context.tokenizer, + device=self.compress_context.device, + serialization_dict=serialization_dict, + **kwargs, + ) + folders.append(save_folder) + + if return_folders: + if len(folders) == 1: + folders = folders[0] + return compressed_model, folders + else: + return compressed_model + + def _get_export_dir(self, output_dir: str, format_str: str) -> str: + """Derive a descriptive export directory from model name and quantization config. + + Must be called after ``post_init()`` so that scheme-resolved attrs + (bits, group_size, data_type, etc.) are available on ``self.quantize_config``. + + Mirrors the logic previously in ``__main__.py`` so callers only need to + pass the base ``output_dir`` and the format string. + """ + # Diffusion models use save_quantized from DiffusionMixin which manages its own + # directory layout (model_index.json + per-component subdirs). Appending a + # scheme-derived suffix here would place files one level too deep. + if getattr(self, "diffusion", False): + return output_dir + + model_name = (getattr(self.model_context.model, "name_or_path", "") or "").rstrip("/") + cfg = self.quantize_config + group_size = cfg.group_size + bits = cfg.bits + data_type = cfg.data_type or "int" + act_bits = cfg.act_bits or 16 + act_data_type = cfg.act_data_type or "float" + + is_gguf = "gguf" in (format_str or "") + last = model_name.split("/")[-1].strip(".") + + if last == "" and not is_gguf: + # model path is just '.' or './' – put inside output_dir with suffix + if group_size <= 0: + suffix = f"afp{act_bits}" if "fp" in act_data_type else f"a{act_bits}" + else: + suffix = f"g{group_size}" + return os.path.join(output_dir, f"w{bits}{suffix}") + + if last == "" and is_gguf: + return output_dir + + if is_gguf: + return os.path.join(output_dir, model_name.split("/")[-1] + "-gguf") + + # Normal case: derive suffix from group_size / act config + if isinstance(group_size, tuple): + assert len(group_size) == 2, f"Only support 2D group_size, but got {group_size}" + suffix = f"g{group_size[0]}x{group_size[1]}" + elif group_size <= 0: + suffix = f"afp{act_bits}" if "fp" in act_data_type else f"a{act_bits}" + else: + suffix = f"g{group_size}" + + prefix = data_type.lower().replace("_", "") if "int" not in data_type or "mx" in data_type else "" + return os.path.join( + output_dir, + model_name.split("/")[-1] + (f"-{prefix}" if prefix else "") + f"-w{bits}{suffix}", + ) + + def quantize_and_save( + self, output_dir: str = "tmp_autoround", format: str = None, inplace: bool = True, **kwargs + ) -> tuple[torch.nn.Module, dict[str, Any]]: + """Quantizes the model and saves it in the specified format(s). 
+ + This function checks the validity of the requested format(s), quantizes + the model accordingly, and saves it to the specified output directory. + If multiple formats are provided, the model is saved separately for each format. + + Args: + output_dir (str, optional): The directory where the quantized model + will be saved. Defaults to "tmp_autoround". + format (str, optional): The quantization format(s) to use, separated + by commas if multiple. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place if only + one format is used. Defaults to True. + **kwargs: Additional arguments for the quantization and saving process. + + Returns: + model: A qdq model or packed model based on the configurations + folders: The folder paths where the quantized models are saved. + + Raises: + ValueError: If an unsupported format is specified. + """ + # Validate and process the specified formats + self.output_dir = output_dir + self.compress_context.output_dir = output_dir + + # check and update the format based on the current configuration + if format and self.formats is None: + logger.warning( + f"quantize_and_save with format is deprecated and will be deleted in auto_round version 1.0." + f" Please use Compressor(format='{format}' instead)." + ) + self.formats = format + if self.formats is None: + logger.info("format is not set, using default auto_round format.") + self.formats = "auto_round" + + # If multiple formats are specified, enforce inplace=False + if len(self.formats.split(",")) > 1: + inplace = False + self.inplace = kwargs.get("inplace", inplace) + kwargs.pop("inplace", None) + + # Perform model quantization + # IMPORTANT: post_init() must run outside any @torch.inference_mode() context + # because AutoScheme's delta-loss selection requires gradient tracking. + self.post_init() + # If post_init() was called manually before quantize_and_save() (e.g. ar.post_init() + # in tests), _resolve_formats saw formats=None and was a no-op. Now that we have set + # self.formats to a default string above, resolve it into OutputFormat objects so that + # quantize() and save_quantized() receive proper objects, not a raw string. + if isinstance(self.formats, str): + self.formats = get_formats(self.formats, self) + self.compress_context.formats = self.formats + # Derive descriptive export dir after post_init so scheme-resolved attrs are available. + _fmt_str = format or (self.formats if isinstance(self.formats, str) else "") + output_dir = self._get_export_dir(output_dir, _fmt_str) + self.output_dir = output_dir + self.compress_context.output_dir = output_dir + if self.static_attention_dtype is not None: + from auto_round.experimental.attention import attention_quant_ctx + + with attention_quant_ctx(self.model_context.model, static_attention_dtype=self.static_attention_dtype): + self.quantize() + self.model_context.quantized = True + elif self.static_kv_dtype is not None: + from auto_round.experimental.kv_cache import kvcache_quant_context + + with kvcache_quant_context(self.model_context.model, static_kv_dtype=self.static_kv_dtype): + self.quantize() + self.model_context.quantized = True + else: + self.quantize() + self.model_context.quantized = True + + # Ensure ShardWriter is ready before saving (deferred from post_init). 
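+        # Illustrative outcome of the _get_export_dir() call above (model name
+        # and scheme values are hypothetical):
+        #   "meta-llama/Llama-3-8B", W4A16, group_size=128, data_type="int"
+        #       -> <output_dir>/Llama-3-8B-w4g128
+        #   any gguf format
+        #       -> <output_dir>/Llama-3-8B-gguf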
+ self._ensure_shard_writer() + + # Save the quantized model in the specified format_list + model, folders = self.save_quantized(output_dir, inplace=inplace, return_folders=True, **kwargs) + memory_monitor.log_summary() + + return model, folders diff --git a/auto_round/compressors_new/calib.py b/auto_round/compressors_new/calib.py new file mode 100644 index 000000000..dea8233df --- /dev/null +++ b/auto_round/compressors_new/calib.py @@ -0,0 +1,1684 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import gc +import time +import traceback +from functools import partial +from typing import Any, Callable, Optional, Union + +import accelerate +import torch +from accelerate.big_modeling import dispatch_model, infer_auto_device_map +from accelerate.utils import get_balanced_memory, get_max_memory +from tqdm import tqdm + +from auto_round import envs +from auto_round.algorithms.alg_config import AlgConfig +from auto_round.calibration.utils import ( + _infer_last_cache_name, + _update_inputs, +) +from auto_round.compressors_new.base import BaseCompressor +from auto_round.compressors_new.utils import ( + _get_quantized_layer_names_outside_blocks, + check_skippable_keywords, + immediate_pack, + init_cache, + is_nv_fp, + is_static_wfp8afp8, + reset_params, +) +from auto_round.logger import logger +from auto_round.modeling.fused_moe.replace_modules import materialize_model_, safe_to_cpu_ +from auto_round.utils import ( + SUPPORTED_LAYER_TYPES, + check_seqlen_compatible, + check_to_quantized, + clear_memory, + compress_layer_names, + convert_module_to_hp_if_necessary, + flatten_list, + get_block_names, + get_module, + hook_ngram_embeddings_on_cpu, + is_auto_device_mapping, + is_quantized_input_module, + memory_monitor, + mv_module_from_gpu, + set_amax_for_all_moe_layers, + set_module, + to_device, + to_dtype, + wrap_block_forward_positional_to_kwargs, +) +from auto_round.utils.device import ( + _force_trim_malloc, + parse_available_devices, +) +from auto_round.wrapper import WrapperLinear, WrapperMultiblock + + +class CalibCompressor(BaseCompressor): + need_calib: bool = True + + def __init__( + self, + config: Union[AlgConfig, list[AlgConfig]], + model: Union[torch.nn.Module, str], + tokenizer=None, + platform="hf", + format=None, + dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", + iters: int = 200, + low_gpu_mem_usage: bool = False, + device_map: Union[str, torch.device, int, dict] = 0, + enable_torch_compile: bool = False, + seed: int = 42, + low_cpu_mem_usage: bool = True, + **kwargs, + ): + self.dataset = dataset + self.iters = iters + super().__init__( + config=config, + model=model, + tokenizer=tokenizer, + platform=platform, + format=format, + low_gpu_mem_usage=low_gpu_mem_usage, + device_map=device_map, + enable_torch_compile=enable_torch_compile, + seed=seed, + low_cpu_mem_usage=low_cpu_mem_usage, + **kwargs, + ) + if iters == 0: + self.lr = 5e-3 + + @torch.no_grad() + def 
try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, last_cache_name=None): + """Attempts to cache intermediate data on GPU, if failed, then using CPU. + + Args: + block_names (list): List of block names to cache data for. + nsamples (int): Number of samples to use for caching. + layer_names (list, optional): List of layer names to cache data for. Defaults to []. + last_cache_name (str, optional): Name of the last cache. Defaults to None. + + Returns: + all_inputs: Cached intermediate data. + + Raises: + Exception: If caching on GPU fails, switches to CPU and caches there. + """ + if is_quantized_input_module(self.model_context.model): + layer_names = [] + if layer_names is None: + layer_names = [] + + block_names = flatten_list(block_names) + self.blocks_requiring_input_ids = [data if isinstance(data, str) else data[0] for data in block_names] + + calibrate_on_cpu = False + cannot_calibrate_on_cpu = False + if self.compress_context.low_gpu_mem_usage or ( + len(block_names) == 1 + and len(layer_names) == 0 + and not self.quantizer.has_qlayer_outside_block + and (last_cache_name is None or last_cache_name in block_names) + and not getattr(self, "mllm", False) + ): + # low_gpu_mem_usage or calibrate only the embedding layer, which is also very fast on CPU + calibrate_on_cpu = True + try: + all_inputs = self.cache_inter_data( + block_names, nsamples, layer_names=[], last_cache_name=last_cache_name + ) + except NotImplementedError as error: + error_msg = str(error) + if "flash_attn::" in error_msg and "CPU" in error_msg: + cannot_calibrate_on_cpu = True # fallback to GPU when flash attention is not supported on CPU + else: + raise error + + if not calibrate_on_cpu or cannot_calibrate_on_cpu: + try: + if any(p.device.type == "meta" for p in self.model_context.model.parameters()): + materialize_model_(self.model_context.model) + + if ( + hasattr(self.model_context.model, "hf_device_map") + and len(self.model_context.model.hf_device_map) > 1 + ): + self.model_context.model = dispatch_model( + self.model_context.model, device_map=self.model_context.model.hf_device_map + ) + else: + # Change this if new device is supported + if str(self.model_context.model.device) == "cpu" and ( + not self.compress_context.device.startswith("hpu") + ): + # type(self.model_context.model._no_split_modules) changes from list to set + # when transformers > 5.0 + no_split_modules = list(getattr(self.model_context.model, "_no_split_modules", [])) + devices = parse_available_devices(self.compress_context.device_map) + + max_memory = get_max_memory() + new_max_memory = {} + if "cpu" not in devices: + devices.append("cpu") + for device in devices: + if ":" in device: + device = int(device.split(":")[-1]) + elif device == "cpu": + device = "cpu" + elif isinstance(device, str): + device = 0 + else: + raise ValueError( + f"Unsupported device {device} in device_map: {self.compress_context.device_map}" + ) + if device not in max_memory: + # Skip devices that aee not reported by accelerate's max_memory. + # This is expected when a device is unavailable or cannot provide memory info. + continue + # Use 90% of the reported max memory to leave headroom for activations, + # temporary tensors, other processes, and allocator fragmentation, reducing + # the chance of runtime OOM while still utilizing most available memory. 
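+                            # Worked example (hypothetical numbers): with
+                            # max_memory = {0: 80 GiB, "cpu": 256 GiB}, the
+                            # assignment below keeps {0: 72 GiB, "cpu": 230.4 GiB}
+                            # in new_max_memory, i.e. each entry scaled by 0.9.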
+ new_max_memory[device] = max_memory[device] * 0.9 + + # If non-CPU devices were requested but none survived, fall back to CPU caching + # via the OOM handler below, avoiding unnecessary dispatch overhead. + requested_non_cpu = any((d != "cpu") for d in devices) + has_non_cpu_memory = any((k != "cpu") for k in new_max_memory) + if requested_non_cpu and not has_non_cpu_memory: + raise torch.OutOfMemoryError( + "No non-CPU device available in accelerate's reported memory. " + "Falling back to CPU caching." + ) + + # Keep ngram_embeddings on CPU + has_ngram_embeddings, raw_ngram_embeddings = hook_ngram_embeddings_on_cpu( + self.model_context.model + ) + new_max_memory = get_balanced_memory( + self.model_context.model, + max_memory=new_max_memory, + no_split_module_classes=no_split_modules, + ) + if hasattr(self.model_context.model, "tie_weights"): + self.model_context.model.tie_weights() + device_map = infer_auto_device_map( + self.model_context.model, + max_memory=new_max_memory, + no_split_module_classes=no_split_modules, + ) + if len(devices) > 1 and "cpu" in device_map.values(): + logger.warning( + "Some layers are offloaded to cpu, which may severely impact calibration speed." + " Please consider using more cards." + ) + + try: + + self.model_context.model = dispatch_model(self.model_context.model, device_map=device_map) + if has_ngram_embeddings: + self.model_context.model.model.ngram_embeddings = raw_ngram_embeddings + except ValueError as e: + if "offload_dir" in e.__str__(): + logger.warning( + f"Due to insufficient resources, disk is used to store the model." + f" `offload_dir={envs.AR_WORK_SPACE}`" + ) + self.model_context.model = dispatch_model( + self.model_context.model, device_map=device_map, offload_dir=envs.AR_WORK_SPACE + ) + else: + raise + else: + + self.model_context.model = self.model_context.model.to(self.compress_context.device) + + all_inputs = self.cache_inter_data( + block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name + ) + if ( + hasattr(self.model_context.model, "hf_device_map") + and len(self.model_context.model.hf_device_map) > 1 + ): + accelerate.hooks.remove_hook_from_submodules(self.model_context.model) + + except torch.OutOfMemoryError as e: + if cannot_calibrate_on_cpu: + raise e + cuda_error_msg = traceback.format_exc() + try: + logger.info("switch to cpu to cache block inputs") + self.compress_context.cache_device = torch.device("cpu") + if self.quantizer.has_qlayer_outside_block or self.__class__.__name__ == "AutoRoundMLLM": + logger.warning( + "we recommend using more GPUs in calibration." + " Otherwise, some layers may fall back to `rtn` mode, which can affect accuracy." + ) + accelerate.hooks.remove_hook_from_submodules(self.model_context.model) + self.model_context.model = mv_module_from_gpu(self.model_context.model) + clear_memory(device_list=self.compress_context.device_list) + # Important change after v0.51, on cpu, we use rtn mode for layers in layer_names + all_inputs = self.cache_inter_data( + block_names, nsamples, layer_names=[], last_cache_name=last_cache_name + ) + except Exception as e: + logger.error(cuda_error_msg) + raise + return all_inputs + + @torch.no_grad() + def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_name=None): + """Save the inputs of block_name for calibration. + + This method temporarily replaces the forward method of the model to capture + the inputs passing through the specified block. It then calibrates the model + using a specified number of samples. 
Finally, it restores the original forward + method and returns the inputs for the specified block. + Args: + block_names (list): The names of the blocks for which inputs are to be saved. + layer_names (list):The names of the layers for which inputs are to be saved. + nsamples (int): The number of samples to use for calibration. + last_cache_name (str, optional): The name of the last layer to be cached, + we could break the forward in this layer to save time + + Returns: + dict: A dictionary containing the inputs for the specified block. + """ + if layer_names is None: + layer_names = [] + + if not self._post_init_done: + self.post_init() + + if hasattr(self, "quantizer") and hasattr(self.quantizer, "attention_mask"): + self.quantizer.attention_mask = [] + + self.inputs = {} + block_names = flatten_list(block_names) + self.to_cached_layers = block_names + layer_names + + tmp_dtype = None # TODO delete this as most model is not fp32 now + ## have bug if block name is not the first block + if (len(block_names) > 1 or len(layer_names) > 0) and self.compress_context.low_gpu_mem_usage: + tmp_dtype = self.model_context.model.dtype + if self.model_context.amp: + if self.model_context.model.dtype != self.model_context.model.dtype: + self.model_context.model = self.model_context.model.to(torch.bfloat16) + else: + self.model_context.model = self.model_context.model.to(torch.float32) ##model on cpu + + self.last_cache_name = _infer_last_cache_name(block_names, layer_names, last_cache_name) + self._cache_target_set = set(self.to_cached_layers) + self._cache_seen_targets = set() + calib_bs = self.quantizer.batch_size + self.hook_handles = [] + self._replace_forward() + try: + self.calib(nsamples, calib_bs) + finally: + # Use finally to recover_forward and delattr in case of that + # self.calib raises NotImplementedError, such as: flash_attn on CPU. + self.model_context.recover_forward() + for attr in ("last_cache_name", "_cache_target_set", "_cache_seen_targets", "to_cached_layers"): + if hasattr(self, attr): + delattr(self, attr) + # Release calibration dataloader to free tokenized sample tensors + if hasattr(self, "dataloader"): + del self.dataloader + res = self.inputs + if tmp_dtype is not None: + self.model_context.model = self.model_context.model.to(tmp_dtype) + + return res + + @torch.no_grad() + def calib(self, nsamples, bs): + """Perform calibration for quantization. + + This method calibrates the model for quantization by processing a specified + number of samples from the calibration dataset. It ensures that the data is + properly formatted and feeds it to the model. If the number of samples processed + is less than the specified number, it logs a warning. If no samples are processed, + it logs an error and exits. + Args: + nsamples (int): The number of samples to use for calibration. 
+ bs (int): The number of samples to use for calibration + """ + from auto_round.calib_dataset import get_dataloader + + need_attention_mask = True + if isinstance(self.dataset, str): + need_attention_mask = False # all supported datasets does not use pad + dataset = self.dataset.replace(" ", "") ##remove all whitespaces + + # slow here + self.dataloader = get_dataloader( + self.model_context.tokenizer, + self.seqlen, + dataset, + self.seed, + bs, + self.nsamples, + ) + else: + self.dataloader = self.dataset + total_cnt = 0 + if self.dataloader.__class__.__name__ == "BatchEncoding": + self.dataloader = [self.dataloader.data] + + for data in self.dataloader: + if data.__class__.__name__ == "BatchEncoding": + data = data.data + if data is None: + continue + if isinstance(data, torch.Tensor): + input_ids = data.to(self.model.device) + data_new = input_ids + elif isinstance(data, str): + if self.model_context.tokenizer is None: + logger.error("please provide tokenizer for string input") + exit(-1) + data = self.model_context.tokenizer( + data, truncation=True, max_length=self.seqlen, return_tensors="pt" + ).data + data_new = {} + for key in data.keys(): + data_new[key] = data[key].to(self.model.device) + input_ids = data_new["input_ids"] + elif isinstance(data, tuple) or isinstance(data, list): + data_new = to_device(data, self.model.device) + input_ids = data_new[0] + else: + data_new = {} + for key in data.keys(): + data_new[key] = to_device(data[key], self.model.device) + if key == "images": + data_new[key] = to_dtype(data_new[key], self.model.dtype) + input_ids = data_new["input_ids"] + if input_ids.shape[-1] < self.seqlen: + continue + if need_attention_mask: + if ( + isinstance(data_new, dict) + and "attention_mask" in data_new + and data_new["attention_mask"] is not None + ): + new_attention_mask = data_new["attention_mask"] + elif ( + self.model_context.tokenizer is not None + and hasattr(self.model_context.tokenizer, "pad_token") + and self.model_context.tokenizer.pad_token is not None + ): + new_attention_mask = (input_ids != self.model_context.tokenizer.pad_token_id).to(torch.long) + else: + # Default all ones + new_attention_mask = torch.ones_like(input_ids, dtype=torch.long) + + # For each sample, check if there are trailing repeated tokens + # If so, set the mask of the last token to 0 + batch_size, seq_len = input_ids.shape + for i in range(batch_size): + last_token = input_ids[i, -1] + # Check for trailing repeats + j = seq_len - 2 + repeated = False + while j >= 0 and input_ids[i, j] == last_token: + repeated = True + new_attention_mask[i, j] = 0 + j -= 1 + # If there was at least one repeat, set last token mask to 0 + if repeated: + new_attention_mask[i, -1] = 0 + + # Workaround: some models treat an all-1 attention mask as equivalent to None and + # will internally replace it with None for block inputs, which can cause tensor + # concatenation / shape-mismatch issues in downstream code. To avoid providing an + # all-1 mask, we force the last token in each sequence to be masked out (set to 0) + # so that the mask is never "all ones". 
This means the model will not attend to the + # last position, so the impact on accuracy is minimal as basically equivalent to dropping a single token + new_attention_mask[:, -1] = 0 + + if not hasattr(self.quantizer, "attention_mask"): + self.quantizer.attention_mask = [] + self.quantizer.attention_mask.extend(list(torch.split(new_attention_mask, 1, dim=0))) + else: + new_attention_mask = None + try: + kwargs = {"use_cache": False} + if new_attention_mask is not None and not (isinstance(data_new, dict) and "attention_mask" in data_new): + kwargs["attention_mask"] = new_attention_mask + + if isinstance(data_new, torch.Tensor): + self.model(data_new, **kwargs) + elif isinstance(data_new, tuple) or isinstance(data_new, list): + self.model(*data_new, **kwargs) + else: + self.model(**data_new, **kwargs) + except NotImplementedError as error: + error_msg = str(error) + # Raise NotImplementedError to fallback to CUDA device + if "flash_attn::" in error_msg and "CPU" in error_msg: + raise NotImplementedError( + "Could not run 'flash_attn::_flash_attn_varlen_forward'" + " with arguments from the 'CPU' backend." + ) + else: + pass + except RuntimeError as error: + error_msg = str(error) + if "The expanded size of the tensor" in str(error_msg) and "must match the existing size" in error_msg: + check_seqlen_compatible(self.seqlen, self.model_context.tokenizer, self.model) + logger.warning( + "When quantization encounters tensor shape mismatch error, " + "you can try to avoid it with batch_size=1" + ) + raise error + except Exception as error: + raise error + + total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1 + if total_cnt >= nsamples: + break + if total_cnt == 0: + logger.error( + f"no data has been cached, please provide more data with sequence length " + f">={self.seqlen} in the dataset or decease the sequence length" + ) + exit(-1) + elif total_cnt < nsamples: + logger.warning_once( + f"An insufficient number of samples likely reduces the accuracy of the quantized model. " + f"Target samples count is {nsamples}, while valid samples count is {total_cnt}" + ) + + @torch.no_grad() + def _get_block_forward_func(self, name: str) -> Callable: + """Gets the forward function. + + Args: + name (str): The name of the function. + Returns: + function: The forward function. + """ + + def post_process_cache_data(batch_size, data, data_name): + """ + Processes store data for batch handling, reshaping if necessary. + + Args: + batch_size (int): The size of the batch. + data: The data value to store, potentially for caching. + data_name (str): Name of the data. + + Returns: + Processed data or None + """ + new_data = data + if data_name in self.model_context.shared_cache_keys: + return None + if batch_size <= 1: + return new_data + if "alibi" in data_name: + if isinstance(data, torch.Tensor): + alibi = data + alibi = alibi.reshape(batch_size, -1, alibi.shape[1], alibi.shape[2]) + new_data = alibi + return new_data + + def forward(m, hidden_states=None, *positional_inputs, **kwargs): + """Rewrite forward function, process and collect input data. + + Args: + hidden_states (torch.Tensor): The hidden states tensor. + *positional_inputs: Variable number of positional arguments. + **kwargs: Variable number of keyword arguments. + + Returns: + NotImplementedError: Getting the first layer inputs and then raise the error to save runtime. 
+ """ + if name not in self.inputs: + self.inputs[name] = {} + init_cache(positional_inputs, self.inputs[name]) + + if self.quantizer.batch_dim is None: + self.quantizer.batch_dim = 0 + if hidden_states is not None and self.quantizer.batch_size > 1: + if hidden_states.shape[0] > self.quantizer.batch_size: + self.quantizer.batch_dim = 1 + if len(hidden_states.shape) > 1 and hidden_states.shape[1] > self.quantizer.batch_size: + logger.error( + "this model has not been supported, " + "please raise an issue in https://github.com/intel/auto-round/issues" + " or try to set the `batch_size` to 1 and " + "`gradient_accumulate_steps` to your current batch size." + ) + exit(-1) + + if hidden_states is not None: + kwargs["hidden_states"] = hidden_states + + for key in kwargs.keys(): + if ( + isinstance(kwargs[key], torch.Tensor) + or isinstance(kwargs[key], list) + or isinstance(kwargs[key], tuple) + ): + if ( + self.has_variable_block_shape + and name not in self.blocks_requiring_input_ids + and key == "hidden_states" + ): + continue + if key not in self.inputs[name].keys(): # initialization + data = to_device(kwargs[key], device=torch.device("cpu")) + if data is None or key in self.model_context.shared_cache_keys: + self.inputs[name][key] = data + continue + if self.quantizer.batch_size <= 1: + self.inputs[name][key] = [data] + else: + data = post_process_cache_data(self.quantizer.batch_size, data, key) + if isinstance(data, torch.Tensor): + self.inputs[name][key] = list(torch.split(data, 1, dim=self.quantizer.batch_dim)) + else: + self.inputs[name][key] = [data] + else: # append cache inputs + new_data = post_process_cache_data(self.quantizer.batch_size, kwargs[key], key) + if new_data is None: # shareable args or NoneType + continue + new_data = to_device(new_data, device=torch.device("cpu")) + if self.quantizer.batch_size <= 1: + self.inputs[name][key].append(new_data) + else: + if isinstance(new_data, torch.Tensor): + self.inputs[name][key].extend( + list(torch.split(new_data, 1, dim=self.quantizer.batch_dim)) + ) + else: + self.inputs[name][key].append(new_data) + elif isinstance(kwargs[key], (str, bool, type(None))): + if key not in self.inputs[name].keys(): + self.inputs[name][key] = kwargs[key] + else: + # Parameters not to be cached + if check_skippable_keywords(key): + logger.warning_once( + f"Please note that '{key}' key" " is not currently used in quantization fine-tuning." 
+ ) + reset_params(self.inputs[name]) + + if self._should_stop_cache_forward(name): + raise NotImplementedError + else: + if hidden_states is not None: + kwargs.pop("hidden_states") + return m.orig_forward(hidden_states, *positional_inputs, **kwargs) + else: + # Currently only for Llama-3.2-Vision-Instruct Series + return m.orig_forward(*positional_inputs, **kwargs) + + return forward + + @torch.no_grad() + def _get_cache_data_hook_for_layer(self, name): + """A forward hook to save input max of a module + :param name: the module name + :return: A hook function.""" + + def cache_input_hook(module, inputs, outputs): + input = inputs + if isinstance(inputs, tuple) or isinstance(input, list): + input = inputs[0] + if name in self.inputs: + self.inputs[name].extend(list(torch.split(input.to("cpu"), 1, dim=0))) + else: + self.inputs[name] = list(torch.split(input.to("cpu"), 1, dim=0)) + + if self._should_stop_cache_forward(name): + raise NotImplementedError + + return cache_input_hook + + def _replace_forward(self): + """Replaces the forward function.""" + + def register_hook(n, m, hook_handles): + if n in self.to_cached_layers and type(m) not in SUPPORTED_LAYER_TYPES: ##block + m.orig_forward = m.forward + m.forward = partial(self._get_block_forward_func(n), m) + elif n in self.to_cached_layers: ##linear layer or conv1d layer + hook_func = self._get_cache_data_hook_for_layer(n) + hook_handle = m.register_forward_hook(hook_func) + hook_handles.append(hook_handle) + + self.model_context.replace_forward(register_hook) + + def _should_stop_cache_forward(self, name: str) -> bool: + """Determine whether current forward pass can stop after caching `name`.""" + if name == self.last_cache_name: + return True + + if self.last_cache_name is not None: + return False + + if not hasattr(self, "_cache_target_set") or not hasattr(self, "_cache_seen_targets"): + return False + + if name in self._cache_target_set: + self._cache_seen_targets.add(name) + + if not self._cache_target_set.issubset(self._cache_seen_targets): + return False + + # Lock the last cache name after the first full forward pass. 
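+ # Once every cache target has been seen, the module observed last becomes the stop
+ # point: subsequent calibration batches can abort their forward pass at that module
+ # instead of re-checking the full target set on every call.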
+ self.last_cache_name = name + return True + + def _preprocess_block_inputs(self, inputs, first_input_name="input_ids"): + input_ids, input_others = self._split_inputs(inputs, first_input_name) + clear_memory(device_list=self.compress_context.device_list) + tmp_dtype = self.model_context.amp_dtype if self.model_context.amp else torch.float32 + if input_ids is not None: + input_ids = to_device(input_ids, self.compress_context.cache_device) + input_ids = to_dtype(input_ids, tmp_dtype) + input_others = to_device(input_others, self.compress_context.cache_device) + + for key in input_others.keys(): + if isinstance(input_others[key], torch.Tensor) and ( + input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16 + ): + input_others[key] = input_others[key].to(tmp_dtype) + elif isinstance(input_others[key], list): + for i in range(len(input_others[key])): + to_dtype(input_others[key][i], tmp_dtype) + return input_ids, input_others + + def _split_inputs(self, inputs: dict, first_input_name: str) -> tuple[torch.Tensor, dict]: + if self.model_context.is_diffusion: + input_id_str = [key for key in inputs.keys() if "hidden_state" in key] + input_ids = {k: inputs.pop(k, None) for k in input_id_str} + input_others = inputs + return input_ids, input_others + input_ids = inputs.get(first_input_name, None) + inputs.pop(first_input_name, None) + input_others = inputs + return input_ids, input_others + + def normalize_decoding_layer_inputs_(self, decoding_layer_inputs: list[tuple[tuple[Any, dict[str, Any]]]]) -> None: + """Replay captured decoding-layer calls to populate ``self.inputs``. + + Converts the raw ``(args, kwargs)`` tuples captured by LLM-Compressor's + input hook into the ``self.inputs`` dict format expected by + :meth:`quantize_block`. The logic mirrors the old-arch implementation in + ``compressors/base.py``. + + Args: + decoding_layer_inputs: + A list of entries captured by a forward hook on the decoding layer. + Each element is a tuple whose first item is ``(args, kwargs)``. + """ + first_block_name = self.quant_block_list[0][0] + + class _FakeDecodingLayer(torch.nn.Module): + + def forward(self, *args, **kwargs): + return args, kwargs + + fake_layer = _FakeDecodingLayer() + fake_layer.orig_forward = fake_layer.forward + fake_layer.forward = partial(self._get_block_forward_func(first_block_name), fake_layer) + + self.inputs = {} + self.last_cache_name = None + for step_input in decoding_layer_inputs: + args, kwargs = step_input[0] + fake_layer(*args, **kwargs) + + def quantize_block( + self, + block: torch.nn.Module, + inputs: tuple, + q_input: Union[torch.Tensor, dict, None] = None, + device: Union[str, torch.device] = "cpu", + auto_offload: bool = True, + ): + """Quantize a single decoded block of the model (public API for LLM-Compressor). + + This method is the new-arch equivalent of the old ``BaseCompressor.quantize_block`` + (see ``compressors/base.py``). It is primarily consumed by LLM-Compressor: + https://github.com/vllm-project/llm-compressor/pull/1994 + + The method normalizes the raw decoding-layer inputs provided by LLM-Compressor, + runs the full infrastructure pipeline (device placement, act-max collection, + reference-output caching) for the given *block*, delegates the pure-algorithm + weight optimization to ``self.quantizer.quantize_block``, then returns the + quantized-block outputs. + + Args: + block: The transformer block (decoder layer) to quantize. + inputs: Raw decoding-layer inputs captured by LLM-Compressor's hook. 
+ Format: list of ``((args, kwargs),)`` tuples as produced by the hook. + q_input: Optional quantized input from the previous block. ``None`` on + the first block. + device: Target device for quantization (e.g. ``"cuda:0"``). + auto_offload: When *True*, use the device-map-aware offloading path; + otherwise move ``block`` directly to ``device``. + + Returns: + tuple: ``(q_outputs, reference_output)`` where *q_outputs* is the + block's output after quantization (or ``None`` when + ``enable_quanted_input`` is ``False``), and *reference_output* is the + full-precision reference output collected before optimization. + """ + assert not self.mllm and not self.diffusion, ( + f"Currently, {self.__class__.__name__} does not support quantize_block " "for MLLM / diffusion models." + ) + + # Ensure post_init has been called (sets up model_context, compress_context, + # quantizer, layer_config, etc.). + if not self._post_init_done: + self.post_init() + + self.normalize_decoding_layer_inputs_(inputs) + block_inputs = self.inputs[self.quant_block_list[0][0]] + input_ids, input_others = self._preprocess_block_inputs(block_inputs, "hidden_states") + + # ── Infrastructure: materialize, dtype convert, device placement ────── + materialize_model_(block) + convert_module_to_hp_if_necessary(block, self.model_context.amp_dtype, device) + + if auto_offload: + if ( + is_auto_device_mapping(self.compress_context.device_map) + and len(self.compress_context.device_list) > 1 + and not self.model_context.is_diffusion + ): + from auto_round.utils.device import set_auto_device_map_for_block_with_tuning + + card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning( + block, + self.compress_context.device_map, + input_ids, + self.compress_context.low_gpu_mem_usage, + self.quantizer.batch_size, + device, + ) + else: + block = block.to(device) + card_0_in_high_risk, loss_device = False, device + else: + card_0_in_high_risk, loss_device = False, device + + if len(self.compress_context.device_list) > 1 and auto_offload: + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + for n, m in block.named_modules(): + if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): + continue + add_hook_to_module(m, AlignDevicesHook(m.tuning_device, io_same_device=True), True) + + # ── Infrastructure: collect reference output and act_max ────────────── + bs = self.quantizer.batch_size * self.quantizer.infer_bs_coeff + if q_input is None: + hook_handles = self.quantizer._register_act_max_hook(block) + reference_output = self.quantizer._get_block_outputs(block, input_ids, input_others, bs) + for h in hook_handles: + h.remove() + else: + reference_output = self.quantizer._get_block_outputs(block, input_ids, input_others, bs) + hook_handles = self.quantizer._register_act_max_hook(block) + if hook_handles: + self.quantizer._get_block_outputs(block, q_input, input_others, bs, save_output=False) + for h in hook_handles: + h.remove() + if input_ids is not q_input: + clear_memory(input_ids, device_list=self.compress_context.device_list) + else: + clear_memory(device_list=self.compress_context.device_list) + input_ids = q_input + + # ── Pure algorithm: delegates to quantizer ──────────────────────────── + mid_iter_mem_check = self.compress_context.low_gpu_mem_usage and card_0_in_high_risk + self.quantizer.quantize_block( + block, + input_ids, + input_others, + reference_output, + loss_device=loss_device, + mid_iter_mem_check=mid_iter_mem_check, + ) + + # ── MoE scale alignment for FP8 dispatch efficiency 
──────────────── + if is_nv_fp(self.quantizer.act_data_type) or is_static_wfp8afp8(self.quantizer): + set_amax_for_all_moe_layers(block, attr_name="act_max") + + # ── Collect quantized-block outputs ─────────────────────────────────── + if self.quantizer.enable_quanted_input: + q_outputs = self.quantizer._get_block_outputs(block, input_ids, input_others, bs) + else: + q_outputs = None + + # ── Cleanup ─────────────────────────────────────────────────────────── + if len(self.compress_context.device_list) > 1: + accelerate.hooks.remove_hook_from_submodules(block) + mv_module_from_gpu(block) + + return q_outputs, reference_output + + def _quantize_blocks( + self, + model: torch.nn.Module, + inputs: dict, + block_names: list, + q_input: torch.Tensor = None, + nblocks: int = 1, + pbar: tqdm = None, + input_others_extra_blocks: dict = None, + ): + """Quantize and dequantize the weights of the specified blocks in the model. + + Args: + model: The PyTorch model to be quantized. + inputs: The input data for quantization. + block_names: The names of the blocks to be quantized and dequantized. + nblocks: The number of blocks to quantize and dequantize. + device: The device for quantization and dequantization. + + Returns: + None + """ + clear_memory(device_list=self.compress_context.device_list) + for n, m in model.named_parameters(): + m.requires_grad_(False) + + input_ids, input_others = self._preprocess_block_inputs(inputs) + + if pbar is None: + pbar = tqdm(range(0, len(block_names), nblocks)) + + for i in range(0, len(block_names), nblocks): + if input_others_extra_blocks and block_names[i] in input_others_extra_blocks: + input_others = input_others_extra_blocks[block_names[i]] + _, input_others = self._preprocess_block_inputs(input_others) + input_others_extra_blocks.pop(block_names[i]) + if i != 0: + pbar.update(1) + if nblocks == 1: + n = block_names[i] + pbar.set_description(f"Quantizing {n}") + m = get_module(model, n) + else: + names = block_names[i : min(i + nblocks, len(block_names))] + pbar.set_description(f"Quantizing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}") + modules = [get_module(model, n) for n in names] + m = WrapperMultiblock(modules) + + if self.compress_context.low_cpu_mem_usage: + if nblocks == 1: + self._offloader.reload(model, n) + else: + self._offloader.reload(model, names) + + block_name_or_names = n if nblocks == 1 else names + + # ── Infrastructure: materialize, dtype convert, device placement ── + materialize_model_(m) + convert_module_to_hp_if_necessary(m, self.model_context.amp_dtype, self.compress_context.device) + + if is_auto_device_mapping(self.compress_context.device_map) and len(self.compress_context.device_list) > 1: + from auto_round.utils.device import set_auto_device_map_for_block_with_tuning + + card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning( + m, + self.compress_context.device_map, + input_ids, + self.compress_context.low_gpu_mem_usage, + self.quantizer.batch_size, + self.compress_context.device, + ) + else: + m = m.to(self.compress_context.device) + card_0_in_high_risk, loss_device = False, self.compress_context.device + + if len(self.compress_context.device_list) > 1: + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + for _n, _mod in m.named_modules(): + if len(list(_mod.children())) != 0 or not hasattr(_mod, "tuning_device"): + continue + add_hook_to_module(_mod, AlignDevicesHook(_mod.tuning_device, io_same_device=True), True) + + # ── Infrastructure: collect reference output 
and act_max ────────── + bs = self.quantizer.batch_size * self.quantizer.infer_bs_coeff + if q_input is None: + hook_handles = self.quantizer._register_act_max_hook(m) + reference_output = self.quantizer._get_block_outputs(m, input_ids, input_others, bs) + for h in hook_handles: + h.remove() + else: + reference_output = self.quantizer._get_block_outputs(m, input_ids, input_others, bs) + hook_handles = self.quantizer._register_act_max_hook(m) + if hook_handles: + self.quantizer._get_block_outputs(m, q_input, input_others, bs, save_output=False) + for h in hook_handles: + h.remove() + + # ── Infrastructure: swap q_input ────────────────────────────────── + if q_input is not None: + if input_ids is not q_input: + clear_memory(input_ids, device_list=self.compress_context.device_list) + else: + clear_memory(device_list=self.compress_context.device_list) + input_ids = q_input + + # ── Pure algorithm: delegates to quantizer ──────────────────────── + mid_iter_mem_check = self.compress_context.low_gpu_mem_usage and card_0_in_high_risk + self.quantizer.quantize_block( + m, + input_ids, + input_others, + reference_output, + loss_device=loss_device, + mid_iter_mem_check=mid_iter_mem_check, + ) + + # ── MoE scale alignment for FP8 dispatch efficiency ──────────────── + if is_nv_fp(self.quantizer.act_data_type) or is_static_wfp8afp8(self.quantizer): + set_amax_for_all_moe_layers(m, attr_name="act_max") + + # ── Infrastructure: collect q_outputs if needed ─────────────────── + if self.quantizer.enable_quanted_input: + q_input = self.quantizer._get_block_outputs(m, input_ids, input_others, bs) + else: + q_input = None + + # ── Infrastructure: hook removal, device cleanup, logging ───────── + if len(self.compress_context.device_list) > 1: + accelerate.hooks.remove_hook_from_submodules(m) + mv_module_from_gpu(m) + if self.enable_torch_compile: + torch._dynamo.reset() + self.quantizer._invalidate_block_forward_cache() + # Keep old-arch semantics: the next block's FP reference input comes + # from the current block's reference output, while q_input (when + # enabled) is only used as the quantized-input companion for the + # next block. 
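+ # Consequently, quantization error does not accumulate along the reference path:
+ # each block's reference output is computed from the previous block's reference
+ # output rather than from its quantized output.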
+ next_input_ids = reference_output + clear_memory( + input_ids if input_ids is not next_input_ids else None, device_list=self.compress_context.device_list + ) + memory_monitor.log_summary() + + # ── Infrastructure: immediate_pack / shard write ────────────────── + if self.compress_context.is_immediate_packing: + for _n, _mod in m.named_modules(): + if hasattr(_mod, "bits") and check_to_quantized(_mod): + from auto_round.compressors_new.utils import immediate_pack as _immediate_pack + + _immediate_pack(_mod.global_name, self.quantizer.layer_config) + + input_ids = next_input_ids + + if self.compress_context.is_immediate_saving: + self.shard_writer.write(m, is_finalize=False) + + if self.compress_context.low_cpu_mem_usage and not self.compress_context.is_immediate_saving: + if nblocks == 1: + self._offloader(model, n, overwrite=True) + else: + for name in names: + self._offloader(model, name, overwrite=True) + if pbar is not None: + pbar.update(1) + + if not self.compress_context.is_immediate_saving: + self.model = mv_module_from_gpu(self.model) + for n, m in self.model.named_modules(): + if hasattr(m, "name"): + delattr(m, "name") + + del q_input + del input_ids + del input_others + del inputs + + clear_memory(device_list=self.compress_context.device_list) + + def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: + """Quantize the model and return the quantized model along with layer configurations.The entry of AutoRound. + Returns: + The quantized model and layer configurations. + """ + self.post_init() + + # Reclaim heap fragmentation from init/post_init before the memory-intensive quantize loop. + gc.collect() + _force_trim_malloc() + + self._check_compatibility() + + if bool(self.quantizer.quant_block_list): + all_blocks = self.quantizer.quant_block_list + else: + all_blocks = get_block_names(self.model_context.model) + + if len(all_blocks) == 0: + logger.warning("could not find blocks, exit with original model") + return self.model_context.model, self.quantizer.layer_config + + layer_names = _get_quantized_layer_names_outside_blocks( + model=self.model_context.model, + layer_config=self.quantizer.layer_config, + supported_types=SUPPORTED_LAYER_TYPES, + quant_block_list=self.quantizer.quant_block_list, + ) + if not self.has_variable_block_shape: + to_cache_block_names = [block[0] for block in all_blocks] + else: + to_cache_block_names = flatten_list(all_blocks) + if len(layer_names) > 0: + logger.info( + "Starting to cache block inputs. This may be slow due to external block layers: %s", layer_names + ) + else: + logger.info("start to cache block inputs") + all_inputs = self.try_cache_inter_data_gpucpu( + to_cache_block_names, + self.nsamples, + layer_names, + ) + self.inputs = all_inputs + is_quantized_embedding = self._quantize_embedding_layer() + clear_memory(device_list=self.compress_context.device_list) + all_q_inputs = None + if is_quantized_embedding: + all_inputs = copy.deepcopy(self.inputs) + clear_memory(self.inputs, device_list=self.compress_context.device_list) + all_q_inputs = self.try_cache_inter_data_gpucpu(to_cache_block_names, self.nsamples, layer_names) + # Remove accelerate dispatch hooks before moving parameters. + # hf_device_map is kept for reference but hooks are no longer needed. 
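+ # The map itself is reused later by _quantize_layers, which re-dispatches the model
+ # when quantized inputs for layers outside blocks need to be re-cached.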
+ if hasattr(self.model_context.model, "hf_device_map") and len(self.model_context.model.hf_device_map) > 1: + accelerate.hooks.remove_hook_from_submodules(self.model_context.model) + self.model_context.model = mv_module_from_gpu(self.model_context.model) + clear_memory(device_list=self.compress_context.device_list) + logger.info("caching done") + if self.compress_context.low_cpu_mem_usage: + if self.model_context.is_model_patched and not self.compress_context.is_immediate_saving: + self._offloader(self.model_context.model, all_blocks, clear_memory=True, device_list=self.device_list) + if not self._offloader.enabled: + self.compress_context.low_cpu_mem_usage = False + else: + self.compress_context.low_cpu_mem_usage = False + if len(all_blocks) > 1: + pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks)) + else: + pbar = tqdm(range(0, len(all_blocks[0]), self.nblocks)) # move the alg warning outside pbar + + start_time = time.time() + for block_names in all_blocks: + inputs = all_inputs[block_names[0]] + all_inputs.pop(block_names[0]) + q_inputs = None + if all_q_inputs is not None: + q_inputs = all_q_inputs[block_names[0]] + all_q_inputs.pop(block_names[0]) + + inputs, q_inputs = _update_inputs(inputs, q_inputs) + + clear_memory(self.inputs, device_list=self.compress_context.device_list) + + if "input_ids" in inputs.keys(): + total_samples = len(inputs["input_ids"]) + if total_samples < self.quantizer.batch_size: + self.quantizer.batch_size = total_samples + logger.warning(f"force the train batch size to {total_samples}") + + self._quantize_blocks( + self.model_context.model, + inputs, + block_names, + q_input=q_inputs if q_inputs is not None else None, + nblocks=self.nblocks, + pbar=pbar, + input_others_extra_blocks=all_inputs, + ) + if self.compress_context.is_immediate_packing and len(self.formats) != 1: + raise ValueError( + f"Expected exactly one packing format when 'immediate_packing' is True, " + f"but got {len(self.formats)} formats." + ) + pbar.set_description("Quantizing done") + pbar.close() + if self.compress_context.low_cpu_mem_usage: + self._offloader.reload(self.model_context.model) + self._quantize_layers(layer_names, all_inputs) + + convert_module_to_hp_if_necessary( + self.model_context.model, self.model_context.amp_dtype, self.compress_context.device, to_cpu=True + ) + if self.compress_context.is_immediate_saving: + self.shard_writer.write(is_finalize=True) + + end_time = time.time() + cost_time = end_time - start_time + logger.info(f"quantization tuning time {cost_time}") + + # Dump a summary + quantized_layers = [] + unquantized_layers = [] + for n, m in self.model_context.model.named_modules(): + if isinstance(m, tuple(SUPPORTED_LAYER_TYPES)): + if check_to_quantized(m): + quantized_layers.append(n) + else: + unquantized_layers.append(n) + elif hasattr(m, "scales") or hasattr(m, "scale"): # packing_immediately + quantized_layers.append(n) + summary_info = ( + f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model" + ) + if len(unquantized_layers) > 0: + compressed_unquantized_layers = compress_layer_names(unquantized_layers) + summary_info += f", unquantized layers: {compressed_unquantized_layers}" + logger.info(summary_info) + + self.model_context.quantized = True + return self.model_context.model, self.quantizer.layer_config + + def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: + """Quantizes specified layers based on inputs and configuration. 
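+ Layers that have no cached calibration inputs fall back to plain RTN (via
+ WrapperLinear with tuning disabled); the remaining layers are quantized with
+ their cached inputs through quantize_layer_outside_block.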
+ + Args: + layer_names (list): list of layer names to quantize. + layer_inputs (dict): Dictionary mapping layer names to input data. + + Returns: + None + """ + # TODO currently we take all the layers outside blocks as post block layers which is not optimal + # if there is no input for layer, we use rtn + + for layer_name in copy.deepcopy(layer_names): + if layer_name not in layer_inputs: + if self.act_bits < 16 and not self.act_dynamic: + if "lm_head" in layer_name: + logger.warning_once( + "Static activation quantization for lm_head is not fully supported yet. " + "If lm_head calibration inputs are missing, activation scale may fall back to unit scale " + "or quantization may be skipped." + ) + # Activation quantization requires collected inputs + msg_prefix = ( + f"Activation max hook for layer '{layer_name}' is unavailable due to " + f"insufficient collected inputs. " + ) + if "fp8_e5m2" in self.act_data_type: + logger.warning(msg_prefix + "Please notes that unit scale is used for this layer.") + else: + logger.warning( + msg_prefix + "Static activation quantization is not supported or ineffective, " + "Skipping quantization for this layer." + ) + layer_names.remove(layer_name) + continue + logger.info(f"using rtn to quantize {layer_name}") + from auto_round.data_type import QUANT_FUNC_WITH_DTYPE + + layer = get_module(self.model, layer_name) + layer = layer.to(self.compress_context.device) + layer = convert_module_to_hp_if_necessary( + layer, self.model_context.amp_dtype, self.compress_context.device + ) + set_module(self.model, layer_name, layer) + + wrapper_layer = WrapperLinear( + layer, + enable_round_tuning=False, + enable_minmax_tuning=False, + enable_norm_bias_tuning=False, + enable_torch_compile=self.enable_torch_compile, + device=self.compress_context.device, + disable_opt_rtn=self.disable_opt_rtn, + ) + new_layer = wrapper_layer.unwrapper({}) + set_module(self.model, layer_name, new_layer) + layer.cpu() + layer_names.remove(layer_name) + if len(layer_names) == 0: + memory_monitor.update() + memory_monitor.log_summary() + return + q_layer_inputs = None + enable_quanted_input = self.enable_quanted_input + has_gguf = False + + if hasattr(self, "formats") and self.formats is not None: + has_gguf = any(format_.is_gguf() for format_ in self.formats) + if has_gguf and self.compress_context.is_immediate_packing: + enable_quanted_input = False + + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1 and enable_quanted_input: + dispatch_model(self.model, self.model.hf_device_map) + + if enable_quanted_input: + logger.info("starting to cache layer inputs for %s, this may be quite slow ", layer_names) + q_layer_inputs = self.try_cache_inter_data_gpucpu([], self.nsamples, layer_names=layer_names) + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: + accelerate.hooks.remove_hook_from_submodules( + self.model + ) # self.model.hf_device_map has not been changed + if not self.compress_context.is_immediate_saving: + self.model = mv_module_from_gpu(self.model) + clear_memory(device_list=self.compress_context.device_list) + quant_layer = self.quantizer.quantize_layer_outside_block + for layer_name in layer_names: + layer_input = layer_inputs[layer_name] + layer_input = to_device(layer_input, self.compress_context.cache_device) + q_layer_input = q_layer_inputs.get(layer_name, None) if q_layer_inputs is not None else None + q_layer_input = to_device(q_layer_input, self.compress_context.cache_device) + quant_layer(layer_name, layer_input, 
q_layer_input, device=self.compress_context.device) + if self.compress_context.is_immediate_packing: + immediate_pack(layer_name, self.quantizer.layer_config) + + if self.compress_context.is_immediate_saving: + m = get_module(self.model, layer_name) + self.shard_writer.write(m, name=layer_name, is_finalize=False) + del layer_input + clear_memory(q_layer_input, device_list=self.compress_context.device_list) + memory_monitor.log_summary() + + def _check_compatibility(self) -> None: + """Checks compatibility of the configurations and model.""" + if ( + self.seqlen is not None + and hasattr(self.model_context.model, "config") + and hasattr(self.model_context.model.config, "max_position_embeddings") + ): + if self.model_context.model.config.max_position_embeddings < self.seqlen: + logger.warning( + f"Change sequence length to {self.model_context.model.config.max_position_embeddings} " + "due to the limitation of max_position_embeddings" + ) + self.seqlen = min(self.seqlen, self.model_context.model.config.max_position_embeddings) + + if self.seqlen is not None and hasattr(self.model_context.tokenizer, "model_max_length"): + if self.model_context.tokenizer.model_max_length < self.seqlen: + logger.warning( + f"Change sequence length to {self.model_context.tokenizer.model_max_length} " + "due to the limitation of model_max_length. " + "You can also try to increase the model_max_length to avoid this issue." + ) + self.seqlen = min(self.seqlen, self.model_context.tokenizer.model_max_length) + + if self.group_size == 0 and "fp8" not in self.data_type: + logger.warning("`group_size==0` is not supported for data_type other than fp8 ") + + if ( + self.bits <= 2 + and (self.iters < 1000 or not getattr(self.quantize_config, "enable_alg_ext", False)) + and self.super_group_size is None + ): + logger.warning( + "for bits <= 2, it is recommended to enable `auto-round-best` " "and turn on `--enable_alg_ext` " + ) + + +class CalibratedRTNCompressor(CalibCompressor): + """CalibCompressor variant for iters=0 RTN that needs calibration data. + + Handles two cases that require forward passes through the model: + - Weight quantization with imatrix (importance-matrix statistics for + improved RTN accuracy on INT / weight-only schemes). + - Activation quantization with static scales (e.g. NVFP4, FP8_STATIC) + where per-tensor or per-channel scale factors must be collected before + the actual quantization step. + + Both cases use OptimizedRTNQuantizer and need a calibration dataset, + which is why they cannot be handled by the zero-shot (no-data) path. + """ + + need_calib: bool = True + + def __init__( + self, + config: AlgConfig, + model: torch.nn.Module, + **kwargs, + ): + kwargs["iters"] = 0 + super().__init__( + config, + model, + **kwargs, + ) + + def _quantize_via_rtn_blockwise(self) -> None: + """Quantize model layers block by block using cached inputs and imatrix.""" + + all_blocks = self.quantizer.quant_block_list or get_block_names(self.model) + if not all_blocks: + raise ValueError("Could not find any blocks. 
Check the model or quant_block_list.") + + if not self.has_variable_block_shape: + to_cache_block_names = [block[0] for block in all_blocks] + else: + to_cache_block_names = flatten_list(all_blocks) + layer_names = _get_quantized_layer_names_outside_blocks( + model=self.model_context.model, + layer_config=self.quantizer.layer_config, + supported_types=SUPPORTED_LAYER_TYPES, + quant_block_list=self.quantizer.quant_block_list, + ) + if ( + self.quantize_config.is_act_quantize + and (not self.quantize_config.act_dynamic or len(layer_names) > 0) + or self.has_variable_block_shape + ): + if len(layer_names) > 0: + logger.warning( + "quantize layers outside blocks for static activation quantizaiton" + " will significantly increase calibration time" + ) + all_inputs = self.try_cache_inter_data_gpucpu(to_cache_block_names, self.nsamples, layer_names) + else: + all_inputs = self.cache_inter_data(to_cache_block_names, self.nsamples) + + # Clear hooks for multi-GPU setups + if hasattr(self.model_context.model, "hf_device_map") and len(self.model_context.model.hf_device_map) > 1: + accelerate.hooks.remove_hook_from_submodules(self.model_context.model) + + pbar = tqdm(range(sum(len(block) for block in all_blocks))) + + for block_names in all_blocks: + first_block = block_names[0] + inputs = all_inputs.pop(first_block) + input_keys = [k for k in inputs if k.startswith("hidden_state")] + if len(input_keys) != 1: + raise RuntimeError( + "hidden_states arg mismatch. Please file an issue at https://github.com/intel/auto-round/issues" + ) + inputs["input_ids"] = inputs.pop(input_keys[0]) + + clear_memory(self.inputs, device_list=self.compress_context.device_list) + + total_samples = len(inputs["input_ids"]) + if total_samples < self.quantize_config.batch_size: + self.quantize_config.batch_size = total_samples + logger.warning(f"Forcing batch size to {total_samples}") + + tmp_dtype = self.model_context.amp_dtype if self.model_context.amp else torch.float32 + + input_ids = to_device(inputs.pop("input_ids"), self.compress_context.cache_device) + input_ids = [id_.to(tmp_dtype) for id_ in input_ids] + + def process_input_others(input_others): + input_others = to_device(input_others, self.compress_context.cache_device) + for key, val in input_others.items(): + if isinstance(val, torch.Tensor) and val.dtype in (torch.float16, torch.bfloat16): + input_others[key] = val.to(tmp_dtype) + elif isinstance(val, list): + input_others[key] = [to_dtype(v, tmp_dtype) for v in val] + return input_others + + input_others = inputs + input_others = process_input_others(input_others) + for block_name in block_names: + if block_name in all_inputs.keys(): + input_others = all_inputs[block_name] + input_others = process_input_others(input_others) + all_inputs.pop(block_name) + pbar.set_description(f"Quantizing {block_name}") + block = get_module(self.model_context.model, block_name) + + # ── Infrastructure: materialize, dtype convert, device placement ── + materialize_model_(block) + block.to("cpu") + block = convert_module_to_hp_if_necessary( + block, dtype=self.model_context.amp_dtype, device=self.compress_context.device + ) + if ( + is_auto_device_mapping(self.compress_context.device_map) + and len(self.compress_context.device_list) > 1 + ): + from auto_round.utils.device import set_auto_device_map_for_block_with_tuning + + set_auto_device_map_for_block_with_tuning( + block, + self.compress_context.device_map, + input_ids, + self.compress_context.low_gpu_mem_usage, + self.quantizer.batch_size, + self.compress_context.device, + ) + 
if len(self.compress_context.device_list) > 1: + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + for _, _mod in block.named_modules(): + if len(list(_mod.children())) != 0 or not hasattr(_mod, "tuning_device"): + continue + add_hook_to_module(_mod, AlignDevicesHook(_mod.tuning_device, io_same_device=True), True) + else: + block = block.to(self.compress_context.device) + + # ── Infrastructure: register act_max hook and run forward pass ── + hook_handles = self.quantizer._register_act_max_hook(block) + input_ids = self.quantizer._get_block_outputs( + block, + input_ids, + input_others, + self.quantizer.batch_size * self.quantizer.infer_bs_coeff, + ) + for h in hook_handles: + h.remove() + + if len(self.compress_context.device_list) > 1: + accelerate.hooks.remove_hook_from_submodules(block) + + if self.compress_context.low_gpu_mem_usage: + block.to("cpu") + self.compress_context.clear_memory() + + # ── Pure algorithm ──────────────────────────────────────────── + self.quantizer.quantize_block(block) + + # ── Infrastructure: cleanup ─────────────────────────────────── + mv_module_from_gpu(block) + + if self.compress_context.low_cpu_mem_usage and not self.compress_context.is_immediate_saving: + self._offloader(self.model_context.model, block_name) + if block_name == block_names[-1]: + clear_memory(input_ids, device_list=self.compress_context.device_list) + else: + clear_memory(device_list=self.compress_context.device_list) + + memory_monitor.log_summary() + pbar.update(1) + pbar.close() + # Process remaining layers not in blocks + # Collect names of quantizable layers not belonging to any block + remain_layer_names = [] + block_name_set = set(name for block in all_blocks for name in block) + for n, m in self.model_context.model.named_modules(): + if not check_to_quantized(m): + continue + # Skip if this layer is part of any block (by prefix match) + if any(n == block_name or n.startswith(f"{block_name}.") for block_name in block_name_set): + continue + remain_layer_names.append(n) + + for name in remain_layer_names: + dtype = None + if self.super_group_size is not None: + dtype = torch.float32 + self.quantizer.quantize_layer_outside_block(name, dtype=dtype) + # clear_memory(device_list=self.compress_context.device_list) + # if self.compress_context.is_immediate_saving: + # shard_writer(self, is_finalize=True) + + def _quant_rtn_with_imatrix(self) -> None: + """Performs RTN quantization using input activation statistics (imatrix). + + This method accumulates per-channel second-moment activation statistics (imatrix) + via forward hooks and uses them to perform RTN quantization. If CUDA memory runs out, + it falls back to CPU-based blockwise quantization. 
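+ Concretely, each hook accumulates torch.sum(x.reshape(-1, C) ** 2, dim=0) into
+ module.imatrix, i.e. a per-input-channel sum of squared activations, together with
+ a running sample count in module.imatrix_cnt.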
+ + Returns: + None + """ + logger.info("start to compute imatrix") + + # Load dataset + from auto_round.calib_dataset import get_dataloader + + if isinstance(self.dataset, str): + if self.model_context.tokenizer is None: + raise ValueError("A tokenizer must be set for the model when using a dataset string.") + dataset_name = self.dataset.replace(" ", "") + self.dataloader = get_dataloader( + self.model_context.tokenizer, + self.seqlen, + dataset_name, + self.seed, + self.quantize_config.batch_size, + self.nsamples, + ) + else: + self.dataloader = self.dataset + + model = self.model_context.model + + # Dispatch multi-GPU model if necessary + if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1: + dispatch_model(model, model.hf_device_map) + + def register_act_hook(model): + """Registers hooks to accumulate activation squared norms into `imatrix`.""" + + def get_imatrix_hook(module, input, output): + input = input[0] if isinstance(input, (tuple, list)) else input + flattened = input.reshape(-1, input.shape[-1]).to(torch.float32) + squared = torch.sum(torch.pow(flattened, 2), dim=0).to(torch.float32) + + if not hasattr(module, "imatrix"): + module.imatrix = squared + module.imatrix_cnt = input.shape[0] + else: + module.imatrix += squared.to(module.imatrix.device) + module.imatrix_cnt += input.shape[0] + + hook_handles = [] + for name, module in model.named_modules(): + if type(module) in SUPPORTED_LAYER_TYPES and check_to_quantized(module): + hook = module.register_forward_hook(get_imatrix_hook) + hook_handles.append(hook) + return hook_handles + + hooks = register_act_hook(model) + + try: + if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1: + import accelerate + + accelerate.hooks.remove_hook_from_submodules(model) + safe_to_cpu_(model) + clear_memory(device_list=self.compress_context.device_list) + self._quantize_via_rtn_blockwise() + except torch.OutOfMemoryError: + cuda_error_msg = traceback.format_exc() + try: + logger.error(cuda_error_msg) + # Final fallback: warn and use CPU-only quantization + logger.warning( + "Fallback to CPU. " + "Consider enabling `low_gpu_mem_usage` or using more GPUs via `--device 0,1,2,3`." + ) + safe_to_cpu_(model) + clear_memory(device_list=self.compress_context.device_list) + if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1: + import accelerate + + accelerate.hooks.remove_hook_from_submodules(model) + + orig_device = self.compress_context.device + self.compress_context.device = "cpu" + self._quantize_via_rtn_blockwise() + self.compress_context.device = orig_device + except Exception as e: + raise + finally: + # Always remove hooks + for hook in hooks: + hook.remove() + + def quantize(self): + """Quantize all modules in the model using RTN (Round-To-Nearest) strategy. + + If the target format includes GGUF with `k`, and optimized RTN is enabled, + blockwise quantization with input caching and imatrix is used. + + Returns: + tuple[nn.Module, Dict[str, Any]]: The quantized model and the layer configuration. + """ + # post_init must be called OUTSIDE @torch.inference_mode() because + # AutoScheme delta-loss selection requires autograd (backward pass). 
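+ # (Tensors created under torch.inference_mode() can never re-enter autograd, while
+ # torch.no_grad(), used for _quantize_impl below, only disables gradient tracking.)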
+ self.post_init() + return self._quantize_impl() + + # Use no_grad instead of inference_mode + # https://github.com/intel/auto-round/issues/1620 + @torch.no_grad() + def _quantize_impl(self): + + formats = getattr(self, "formats", None) or [] + if not (any(fmt.is_gguf() for fmt in formats) or self.super_bits is not None): + self._quantize_embedding_layer() # leave to gguf itself to handle + + # Release memory + clear_memory(device_list=self.compress_context.device_list) + + enable_imatrix = False + if not getattr(self, "disable_opt_rtn", True): + formats = getattr(self, "formats", None) or [] + has_gguf_k = ( + any(fmt.is_gguf() and "k" in fmt.output_format for fmt in formats) or self.super_bits is not None + ) + if has_gguf_k: + enable_imatrix = True + elif self.data_type == "int" and self.sym and self.bits < 8: + enable_imatrix = True + + if enable_imatrix: + self._quant_rtn_with_imatrix() + else: + self._quantize_via_rtn_blockwise() + + convert_module_to_hp_if_necessary( + self.model_context.model, + self.model_context.amp_dtype, + self.compress_context.device, + ) + if self.compress_context.low_cpu_mem_usage: + self._offloader.reload(self.model_context.model) + if self.compress_context.is_immediate_saving: + self.shard_writer.write(is_finalize=True) + + self.model_context.quantized = True + return self.model_context.model, self.quantizer.layer_config diff --git a/auto_round/compressors_new/config.py b/auto_round/compressors_new/config.py new file mode 100644 index 000000000..dd028b6cd --- /dev/null +++ b/auto_round/compressors_new/config.py @@ -0,0 +1,296 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
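+# Extra / legacy configuration knobs, grouped by concern (tuning, scheme, MLLM, diffusion).
+# Illustrative usage with names defined in this file: ExtraConfig(bits=4, group_size=128,
+# quant_nontext_module=True) routes each argument to the matching sub-config, and
+# ExtraConfig(...).to_dict() flattens the sub-configs back into a single kwargs dict.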
+from __future__ import annotations + +from dataclasses import dataclass, fields +from typing import Any, Callable, Optional, Union + +import torch + + +class ExtraConfig: + """Class for extra or legacy configs.""" + + _model_config = None + _scheme_config = None + _tuning_config = None + _mllm_config = None + _diffusion_config = None + + def __init__( + self, + # tuning + amp: bool = True, + disable_opt_rtn: bool | None = None, + enable_alg_ext: bool = False, + enable_minmax_tuning: bool = True, + enable_norm_bias_tuning: bool = False, + enable_quanted_input: bool = True, + enable_deterministic_algorithms: bool = False, + lr: float = None, + lr_scheduler: Callable = None, + minmax_lr: float = None, + nblocks: int = 1, + to_quant_block_names: Union[str, list, None] = None, + scale_dtype: str = "fp16", + # scheme + bits: int = None, + group_size: int = None, + sym: bool = None, + data_type: str = None, + act_bits: int = None, + act_group_size: int = None, + act_sym: bool = None, + act_data_type: str = None, + act_dynamic: bool = None, + super_bits: int = None, + super_group_size: int = None, + static_kv_dtype: Union[str, torch.dtype] = None, + quant_lm_head: bool = False, + ignore_layers: str = None, + # mllm + processor: Callable = None, + image_processor: Callable = None, + quant_nontext_module: bool = False, + extra_data_dir: str = None, + template: str = None, + # diffusion + guidance_scale: float = 7.5, + num_inference_steps: int = 50, + generator_seed: int = None, + ): + """Initialize + + Args: + amp (bool): Whether to use automatic mixed precision (default is True). + disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0). Defaults to True. + enable_alg_ext (bool, optional): Enable algorithm extension (primarily for INT2). Defaults to False. + enable_minmax_tuning (bool, optional): Enable weight min-max tuning. Defaults to True. + enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning. + enable_quanted_input (bool): Whether to use quantized input data (default is True). + enable_deterministic_algorithms (bool): Whether to use deterministic_algorithms. + lr (float): The learning rate (default is 0.005). + lr_scheduler: The learning rate scheduler to be used. + minmax_lr (float): The learning rate for min-max tuning (default is None). + nblocks (int): Number of blocks (default is 1). + quant_lm_head (bool): Whether to quant lm_head. + to_quant_block_names (str|list): Names of quantitative blocks, please use commas to separate them. + scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels + bits (int, optional): Weight quantization bits. Defaults to 4. + group_size (int, optional): Weight quantization group size. Defaults to 128. + sym (bool, optional): Symmetric weight quantization. Defaults to True. + data_type (str, optional): Weight data type string, e.g., "int". Defaults to "int". + act_bits (int, optional): Activation quantization bits. Defaults to 16. + act_group_size (int, optional): Activation group size. Defaults to None. + act_sym (bool, optional): Symmetric activation quantization. Defaults to None. + act_data_type (str, optional): Activation data type; inherits weight dtype if None and act_bits < 16. + act_dynamic (bool, optional): Dynamic activation quantization. Defaults to True. + super_bits (int): number of scale and mins quant bits for double quant. + super_group_size (int): the number of super group size when use double quant. 
+ static_kv_dtype (str): The data type of kv-cache to be used. + processor: Any multi-modal model will require an object to encode or + decode the data that groups several modalities (among text, vision and audio). + image_processor: Image processor for special model like llava. + quant_nontext_module: Whether to quantize nontext module. + extra_data_dir: The path of extra data such as images, audio and videos. + template: The path or name of template used to specify process for different MLLMs. + guidance_scale (float): Control how much the image generation process follows the text prompt. + The more it is, the more closely it follows the prompt (default is 7.5). + num_inference_steps (int): The reference number of denoising steps (default is 50). + generator_seed (int): A seed that controls the initial noise for image generation (default is None). + """ + self.tuning_config = TuningExtraConfig( + amp=amp, + disable_opt_rtn=disable_opt_rtn, + enable_alg_ext=enable_alg_ext, + enable_minmax_tuning=enable_minmax_tuning, + enable_norm_bias_tuning=enable_norm_bias_tuning, + enable_quanted_input=enable_quanted_input, + enable_deterministic_algorithms=enable_deterministic_algorithms, + lr=lr, + lr_scheduler=lr_scheduler, + minmax_lr=minmax_lr, + nblocks=nblocks, + to_quant_block_names=to_quant_block_names, + scale_dtype=scale_dtype, + ) + self.scheme_config = SchemeExtraConfig( + bits=bits, + group_size=group_size, + sym=sym, + data_type=data_type, + act_bits=act_bits, + act_group_size=act_group_size, + act_sym=act_sym, + act_data_type=act_data_type, + act_dynamic=act_dynamic, + super_bits=super_bits, + super_group_size=super_group_size, + static_kv_dtype=static_kv_dtype, + quant_lm_head=quant_lm_head, + ignore_layers=ignore_layers, + ) + self.mllm_config = MLLMExtraConfig( + processor=processor, + image_processor=image_processor, + quant_nontext_module=quant_nontext_module, + extra_data_dir=extra_data_dir, + template=template, + ) + self.diffusion_config = DiffusionExtraConfig( + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator_seed=generator_seed, + ) + + @property + def tuning_config(self): + return self._tuning_config + + @tuning_config.setter + def tuning_config(self, config: TuningExtraConfig): + assert isinstance( + config, TuningExtraConfig + ), f"tuning_config should be ModelExtraConfig, but got {config.__class__.__name__}" + self._tuning_config = config + + @property + def scheme_config(self): + return self._scheme_config + + @scheme_config.setter + def scheme_config(self, config: SchemeExtraConfig): + assert isinstance( + config, SchemeExtraConfig + ), f"scheme_config should be SchemeExtraConfig, but got {config.__class__.__name__}" + self._scheme_config = config + + @property + def mllm_config(self): + return self._mllm_config + + @mllm_config.setter + def mllm_config(self, config: MLLMExtraConfig): + if config is None: + self._mllm_config = None + else: + assert isinstance( + config, MLLMExtraConfig + ), f"mllm_config should be MLLMExtraConfig, but got {config.__class__.__name__}" + self._mllm_config = config + + @property + def diffusion_config(self): + return self._diffusion_config + + @diffusion_config.setter + def diffusion_config(self, config: DiffusionExtraConfig): + if config is None: + self._diffusion_config = None + else: + assert isinstance( + config, DiffusionExtraConfig + ), f"diffusion_config should be DiffusionExtraConfig, but got {config.__class__.__name__}" + self._diffusion_config = config + + def to_dict(self): + output_dict = 
{} + for config in self.__dict__.values(): + if config: + output_dict.update(config.to_dict()) + return output_dict + + +@dataclass +class BaseExtraConfig: + + @classmethod + def get_attributes(cls: "BaseExtraConfig") -> list[str]: + return [field.name for field in fields(cls)] + + def __getitem__(self, key: str): + if key not in self.get_attributes(): + raise KeyError(f"{key} is not a valid attribute") + return getattr(self, key) + + def __setitem__(self, key: str, value: None | int | str): + if key not in self.get_attributes(): + raise KeyError(f"{key} is not a valid attribute") + setattr(self, key, value) + + def __contains__(self, item): + return item in self.get_attributes() + + def to_dict(self): + return self.__dict__ + + def is_default(self): + for field in fields(self): + default_value = field.default + current_value = getattr(self, field.name) + if current_value != default_value: + return False + return True + + +@dataclass +class TuningExtraConfig(BaseExtraConfig): + amp: bool = True + disable_opt_rtn: bool | None = None + enable_alg_ext: bool = False + enable_minmax_tuning: bool = True + enable_norm_bias_tuning: bool = False + enable_quanted_input: bool = True + enable_deterministic_algorithms: bool = False + lr: float = None + lr_scheduler: Callable = None + minmax_lr: float = None + nblocks: int = 1 + to_quant_block_names: Union[str, list, None] = None + scale_dtype: str = "fp16" + + +@dataclass +class SchemeExtraConfig(BaseExtraConfig): + bits: int = None + group_size: int = None + sym: bool = None + data_type: str = None + act_bits: int = None + act_group_size: int = None + act_sym: bool = None + act_data_type: str = None + act_dynamic: bool = None + super_bits: int = None + super_group_size: int = None + static_kv_dtype: Union[str, torch.dtype] = None + static_attention_dtype: Union[str, torch.dtype] = None + quant_lm_head: bool = False + ignore_layers: str = None + + +@dataclass +class MLLMExtraConfig(BaseExtraConfig): + processor: Callable = None + image_processor: Callable = None + quant_nontext_module: bool = False + extra_data_dir: str = None + template: str = None + + +@dataclass +class DiffusionExtraConfig(BaseExtraConfig): + guidance_scale: float = 7.5 + num_inference_steps: int = 50 + generator_seed: int = None diff --git a/auto_round/compressors_new/diffusion_mixin.py b/auto_round/compressors_new/diffusion_mixin.py new file mode 100644 index 000000000..515839c1f --- /dev/null +++ b/auto_round/compressors_new/diffusion_mixin.py @@ -0,0 +1,271 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Union + +import torch +from tqdm import tqdm + +from auto_round.logger import logger +from auto_round.utils.model import wrap_block_forward_positional_to_kwargs + + +class DiffusionMixin: + """Diffusion-specific functionality mixin. + + This mixin adds diffusion model-specific functionality to any compressor + (CalibCompressor, ZeroShotCompressor, ImatrixCompressor, etc). 
It handles + diffusion models (like Stable Diffusion, FLUX) that require special pipeline + handling and data generation logic. + + Can be combined with: + - CalibCompressor (for AutoRound with calibration) + - ImatrixCompressor (for RTN with importance matrix) + - ZeroShotCompressor (for basic RTN) + + Diffusion-specific parameters: + guidance_scale: Control how much image generation follows text prompt + num_inference_steps: Reference number of denoising steps + generator_seed: Seed for initial noise generation + + Design note: + ``ModelContext._load_model()`` loads the diffusion pipeline and sets + ``model_context.pipe`` and ``model_context.model`` (the unet/transformer). + This mixin reads ``self.model_context.pipe`` directly during calibration and + saving so that ``model_context`` remains the single source of truth. + """ + + def __init__(self, *args, guidance_scale=7.5, num_inference_steps=50, generator_seed=None, **kwargs): + # Store diffusion-specific attributes + self.guidance_scale = guidance_scale + self.num_inference_steps = num_inference_steps + self.generator_seed = generator_seed + + # Mirror old-arch DiffusionCompressor.__init__: when iters > 0, diffusion calibration + # cannot use batch_size > 1 for non-text modules; fold the extra batch into + # gradient_accumulate_steps so the effective sample count is unchanged. + # The authoritative batch_size lives on the AlgConfig (args[0]); kwargs may also + # carry it from AutoRoundCompatible. Patch BOTH (same pattern as MLLMMixin). + iters = kwargs.get("iters", None) + _alg_cfg = args[0] if args else None + if iters is None and _alg_cfg is not None: + cfgs = _alg_cfg if isinstance(_alg_cfg, list) else [_alg_cfg] + for cfg in cfgs: + if hasattr(cfg, "iters") and cfg.iters is not None: + iters = cfg.iters + break + if iters is None: + iters = 200 + + if iters > 0: + batch_size = kwargs.get("batch_size", None) + if batch_size is None and _alg_cfg is not None: + cfgs = _alg_cfg if isinstance(_alg_cfg, list) else [_alg_cfg] + for cfg in cfgs: + if hasattr(cfg, "batch_size") and cfg.batch_size is not None: + batch_size = cfg.batch_size + break + if batch_size is not None and batch_size != 1: + grad_acc = kwargs.get("gradient_accumulate_steps", 1) + if _alg_cfg is not None: + cfgs = _alg_cfg if isinstance(_alg_cfg, list) else [_alg_cfg] + for cfg in cfgs: + if hasattr(cfg, "gradient_accumulate_steps") and cfg.gradient_accumulate_steps is not None: + grad_acc = cfg.gradient_accumulate_steps + break + new_grad_acc = batch_size * grad_acc + kwargs["gradient_accumulate_steps"] = new_grad_acc + kwargs["batch_size"] = 1 + if _alg_cfg is not None: + cfgs = _alg_cfg if isinstance(_alg_cfg, list) else [_alg_cfg] + for cfg in cfgs: + if hasattr(cfg, "batch_size"): + cfg.batch_size = 1 + if hasattr(cfg, "gradient_accumulate_steps"): + cfg.gradient_accumulate_steps = new_grad_acc + logger.warning( + f"reset batch_size({batch_size}) to 1 and " + f"gradient_accumulate_steps to {new_grad_acc} " + f"because batch_size={batch_size} cannot be used for calibrating non-text modules." + ) + + # Call parent class __init__ (will be CalibCompressor, ImatrixCompressor, etc) + super().__init__(*args, **kwargs) + + # Mirror old-arch DiffusionCompressor._align_device_and_dtype: unconditionally + # cast the full diffusion pipeline (VAE, text encoder, etc.) to the transformer's + # dtype so that calibration's pipe(...) call doesn't crash with dtype mismatches + # when the transformer is force-cast to bf16 for activation quantization. 
+ # Note: pipe.dtype only reflects the primary component, so an equality check would + # miss mixed-dtype pipelines where e.g. the VAE is still float32. + pipe = getattr(self.model_context, "pipe", None) + model = getattr(self.model_context, "model", None) + if pipe is not None and model is not None: + is_nextstep = hasattr(model, "config") and getattr(model.config, "model_type", None) == "nextstep" + if not is_nextstep: + pipe.to(model.dtype) + + def _get_block_forward_func(self, name: str): + """Diffusion models pass positional args; wrap the base forward func accordingly. + + The MRO guarantees that super() resolves to CalibCompressor._get_block_forward_func, + mirroring the old-arch pattern in compressors/diffusion/compressor.py. + """ + return wrap_block_forward_positional_to_kwargs(super()._get_block_forward_func(name)) + + def _should_stop_cache_forward(self, name: str) -> bool: + """Diffusion models must run all denoising steps to collect enough inputs. + + Mirrors old-arch DiffusionCompressor._should_stop_cache_forward which always + returns False so the pipeline never exits early after the first block hit. + Without this, CalibCompressor._should_stop_cache_forward would stop after the + first inference step, yielding only nsamples inputs instead of + nsamples * num_inference_steps. + """ + return False + + @torch.no_grad() + def calib(self, nsamples, bs): + """Perform diffusion-specific calibration for quantization. + + Override parent's calib method to use diffusion dataset loading logic. + The diffusion pipeline is read from ``self.model_context.pipe``. + """ + from auto_round.compressors.diffusion.dataset import get_diffusion_dataloader + + pipe = self.model_context.pipe + if pipe is None: + raise ValueError( + "Diffusion pipeline not found in model_context. " "Ensure the model was loaded as a diffusion model." + ) + + logger.warning( + "Diffusion model will catch nsamples * num_inference_steps inputs, " + "you can reduce nsamples or num_inference_steps if OOM or take too much time." + ) + if isinstance(self.dataset, str): + dataset = self.dataset.replace(" ", "") + self.dataloader, self.batch_size, self.gradient_accumulate_steps = get_diffusion_dataloader( + dataset=dataset, + bs=self.batch_size, + seed=self.seed, + nsamples=self.nsamples, + gradient_accumulate_steps=self.gradient_accumulate_steps, + ) + else: + self.dataloader = self.dataset + total_cnt = 0 + + total = nsamples if not hasattr(self.dataloader, "len") else min(nsamples, len(self.dataloader)) + + if ( + hasattr(self.model, "hf_device_map") + and len(self.model.hf_device_map) > 1 + and pipe.device != self.model.device + and torch.device(self.model.device).type in ["cuda", "xpu"] + ): + logger.error( + "Diffusion model is activated sequential model offloading, it will crash during moving to GPU/XPU. " + "Please use model path for quantization or " + "move the pipeline object to GPU/XPU before passing them into API." 
+ ) + exit(-1) + + if pipe.device != self.model.device: + pipe.to(self.model.device) + with tqdm(range(1, total + 1), desc="cache block inputs") as pbar: + for ids, prompts in self.dataloader: + if isinstance(prompts, tuple): + prompts = list(prompts) + try: + pipe( + prompts, + guidance_scale=self.guidance_scale, + num_inference_steps=self.num_inference_steps, + generator=( + None + if self.generator_seed is None + else torch.Generator(device=pipe.device).manual_seed(self.generator_seed) + ), + ) + except NotImplementedError: + pass + except Exception as error: + raise error + step = len(prompts) + total_cnt += step + pbar.update(step) + if total_cnt >= nsamples: + break + if total_cnt == 0: + logger.error( + f"no data has been cached, please provide more data with sequence length >={self.seqlen} in the " + f"dataset or decease the sequence length" + ) + exit(-1) + elif total_cnt < nsamples: + logger.warning( + f"Insufficient number of samples collected may affect the quantization. " + f"target samples count is {nsamples}, while valid samples count is {total_cnt}" + ) + if total_cnt < self.batch_size: + raise ValueError( + f"valid samples is less than batch_size({self.batch_size})," + " please adjust self.batch_size or seqlen." + ) + max_len = (total_cnt // self.batch_size) * self.batch_size + for k, v in self.inputs.items(): + for key in v: + if isinstance(v[key], list) and len(v[key]) == total_cnt: + self.inputs[k][key] = v[key][:max_len] + + # torch.cuda.empty_cache() + + def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs): + """Save the quantized model to the specified output directory in the specified format. + + Args: + output_dir (str, optional): The directory to save the quantized model. Defaults to None. + format (str, optional): The format in which to save the model. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place. Defaults to True. + **kwargs: Additional keyword arguments specific to the export format. + + Returns: + object: The compressed model object. 
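+
+        Example (illustrative; hypothetical output path):
+            >>> compressor.save_quantized("./qmodel", format="auto_round")
+
+            The quantized transformer/unet component is exported under ``./qmodel/<component_name>``,
+            the remaining pipeline components (e.g. vae, text_encoder) are saved via their own
+            ``save_pretrained``, and the pipeline config is written to ``./qmodel``.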
+ """ + if output_dir is None: + return super().save_quantized(output_dir, format=format, inplace=inplace, **kwargs) + + pipe = self.model_context.pipe + compressed_model = None + for name in pipe.components.keys(): + val = getattr(pipe, name) + sub_module_path = ( + os.path.join(output_dir, name) if os.path.basename(os.path.normpath(output_dir)) != name else output_dir + ) + if ( + hasattr(val, "config") + and hasattr(val.config, "_name_or_path") + and val.config._name_or_path == self.model.config._name_or_path + ): + compressed_model = super().save_quantized( + output_dir=sub_module_path if not self.compress_context.is_immediate_saving else output_dir, + format=format, + inplace=inplace, + **kwargs, + ) + elif val is not None and hasattr(val, "save_pretrained"): + val.save_pretrained(sub_module_path) + pipe.config.save_pretrained(output_dir) + return compressed_model diff --git a/auto_round/compressors_new/entry.py b/auto_round/compressors_new/entry.py new file mode 100644 index 000000000..7d519e007 --- /dev/null +++ b/auto_round/compressors_new/entry.py @@ -0,0 +1,525 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Any, Callable, Optional, Union + +import torch + +from auto_round.algorithms.alg_config import AlgConfig +from auto_round.algorithms.quantization.rtn.config import RTNConfig +from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig +from auto_round.algorithms.transforms.rotation.config import RotationConfig as _NewArchRotationConfig +from auto_round.auto_scheme.gen_auto_scheme import AutoScheme +from auto_round.compressors_new.calib import CalibCompressor, CalibratedRTNCompressor +from auto_round.compressors_new.utils import check_need_act_calibration +from auto_round.compressors_new.zero_shot import ZeroShotCompressor +from auto_round.logger import logger +from auto_round.schemes import QuantizationScheme, _parse_scheme + + +def _preview_resolved_attrs(config, scheme=None) -> dict: + """Resolve scheme attributes without mutating config, for routing decisions. + + Called in ``AutoRound.__new__`` before the concrete compressor class is + chosen. ``SchemeMixin.resolve_scheme()`` will do the authoritative + resolution later; this is just a lightweight preview so routing logic + (``enable_imatrix``, ``needs_act_calib``, etc.) can use the correct values + even when the user specified only ``scheme=`` without explicit bit/dtype args. + + Returns: + dict: resolved attributes (may be empty if scheme cannot be previewed). + """ + if isinstance(scheme, AutoScheme): + # AutoScheme needs model info — cannot preview, rely on raw config attrs + return {} + scheme_attr_names = QuantizationScheme.get_attributes() + user_overrides = {k: getattr(config, k) for k in scheme_attr_names if getattr(config, k, None) is not None} + try: + _, _, final_attrs = _parse_scheme(scheme, user_overrides) + return final_attrs + except Exception: + return {} + + +def _eager_validate_scheme(config, scheme=None) -> None: + """Eagerly validate scheme/config constraints at construction time. + + Mirrors the old-arch ``_check_configs()`` call in ``BaseCompressor.__init__``. + Raises ``ValueError`` or ``NotImplementedError`` immediately if the scheme + contains config-only invalid combinations (e.g. tuple group_size with non-fp8 + weight dtype) so that callers get a fast failure rather than a deferred error + buried inside ``post_init()``. + + ``AutoScheme`` is skipped because it requires model information. 
+ """ + if isinstance(scheme, AutoScheme): + return + + scheme_attr_names = QuantizationScheme.get_attributes() + user_overrides = {k: getattr(config, k) for k in scheme_attr_names if getattr(config, k, None) is not None} + try: + _, _, final_attrs = _parse_scheme(scheme, user_overrides) + except (ValueError, NotImplementedError): + raise + except Exception: + return # Other parse errors are deferred to post_init + + import copy + + temp_config = copy.copy(config) + for key, value in final_attrs.items(): + setattr(temp_config, key, value) + temp_config.check_config() # raises ValueError / NotImplementedError if invalid + + +# --------------------------------------------------------------------------- +# Compressor-class registry +# --------------------------------------------------------------------------- +# Maps (model_type, base_class_name) → combined class, created lazily. +_COMPRESSOR_REGISTRY: dict[tuple[str, str], type] = {} + + +def _get_compressor_class(model_type: str, base_cls: type) -> type: + """Return the compressor class for *base_cls* wired with the right model-type Mixin. + + For ``model_type == "llm"`` the bare *base_cls* is returned unchanged. + For ``"mllm"`` and ``"diffusion"`` the corresponding Mixin is prepended via + :func:`type` and the result is cached in ``_COMPRESSOR_REGISTRY`` so that + each ``(model_type, base_cls)`` pair is created at most once per process. + """ + if model_type == "llm": + return base_cls + key = (model_type, base_cls.__name__) + if key in _COMPRESSOR_REGISTRY: + return _COMPRESSOR_REGISTRY[key] + if model_type == "mllm": + from auto_round.compressors_new.mllm_mixin import MLLMMixin + + mixin = MLLMMixin + elif model_type == "diffusion": + from auto_round.compressors_new.diffusion_mixin import DiffusionMixin + + mixin = DiffusionMixin + else: + return base_cls + combined = type(f"{model_type.capitalize()}{base_cls.__name__}", (mixin, base_cls), {}) + _COMPRESSOR_REGISTRY[key] = combined + return combined + + +def is_weight_scheme(scheme): + if isinstance(scheme, str): + return scheme.upper().startswith("W") + if isinstance(scheme, dict): + return all(isinstance(s, str) and s.upper().startswith("W") for s in scheme.values()) + if isinstance(scheme, AutoScheme): + opts = scheme.options + if isinstance(opts, (list, tuple)): + return all(isinstance(s, str) and s.upper().startswith("W") for s in opts) + if isinstance(opts, str): + return opts.upper().startswith("W") + return False + + +def detect_model_type(model): + """Detect the type of model (LLM, MLLM, or Diffusion). + + Args: + model: Model instance or model path string + + Returns: + str: "mllm", "diffusion", or "llm" + """ + from auto_round.utils import is_diffusion_model, is_mllm_model + + # Check if it's a diffusion model first (more specific) + if is_diffusion_model(model): + return "diffusion" + + # Check if it's an MLLM + if is_mllm_model(model): + return "mllm" + + # Default to standard LLM + return "llm" + + +class AutoRound(object): + SKIP_ARGS = ("local_args", "kwargs", "cls", "alg_configs", "quant_config", "quant_configs") + + # Mapping from string alias to config class (and optional defaults override). 
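+    # Illustrative usage (hypothetical model path): AutoRound("rtn", "/models/opt-125m") and
+    # AutoRound(RTNConfig(), "/models/opt-125m") are equivalent, since _resolve_config() below
+    # expands a string alias into the corresponding config instance with default parameters.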
+ _CONFIG_ALIASES: dict[str, type] = { + "sign_round": SignRoundConfig, + "signround": SignRoundConfig, + "rtn": RTNConfig, + "hadamard": _NewArchRotationConfig, + } + + @classmethod + def _resolve_config(cls, config: Union[str, AlgConfig, list]) -> Union[AlgConfig, list[AlgConfig]]: + """Convert string alias(es) to the corresponding config instance(s) with default parameters.""" + if isinstance(config, str): + key = config.strip().lower() + if key not in cls._CONFIG_ALIASES: + raise ValueError(f"Unknown config alias '{config}'. " f"Supported: {list(cls._CONFIG_ALIASES.keys())}") + return cls._CONFIG_ALIASES[key]() + if isinstance(config, list): + return [cls._resolve_config(c) for c in config] + return config + + def __new__( + cls, + alg_configs: Union[str, AlgConfig, list[Union[str, AlgConfig]]], + model: Union[torch.nn.Module, str], + tokenizer=None, + platform="hf", + format=None, + scheme="W4A16", + low_gpu_mem_usage: bool = False, + device_map: Union[str, torch.device, int, dict] = 0, + enable_torch_compile: bool = False, + seed: int = 42, + low_cpu_mem_usage: bool = True, + layer_config=None, + nsamples: int = None, + seqlen: int = None, + **kwargs, + ): + from auto_round.algorithms.quantization.config import QuantizationConfig + + # Resolve string alias(es) to config instance(s) before routing. + alg_configs = cls._resolve_config(alg_configs) + + # Extract the single QuantizationConfig from a list; validate at most one exists. + if isinstance(alg_configs, list): + quant_configs = [c for c in alg_configs if isinstance(c, QuantizationConfig)] + if len(quant_configs) == 0: + raise ValueError("At least one QuantizationConfig (SignRoundConfig / RTNConfig) is required.") + if len(quant_configs) > 1: + raise ValueError( + f"Only one QuantizationConfig is allowed, but got {len(quant_configs)}: " + f"{[type(c).__name__ for c in quant_configs]}" + ) + quant_config = quant_configs[0] + else: + quant_config = alg_configs + + # Eagerly validate scheme constraints that do not require model info. + # This mirrors old-arch _check_configs() called at __init__ time so that + # callers get ValueError/NotImplementedError on construction, not deferred. + _eager_validate_scheme(quant_config, scheme) + + # using different compressor base on AlgConfigs + local_args = {k: v for k, v in locals().items() if k not in cls.SKIP_ARGS} + + # Detect model type to determine if we need special compressor + model_type = detect_model_type(model) + + # If the user explicitly passes processor/image_processor, treat as MLLM even if + # auto-detection missed it (mirrors the has_multimodal_assets check in autoround.py). + has_multimodal_assets = kwargs.get("processor") is not None or kwargs.get("image_processor") is not None + if has_multimodal_assets and model_type != "mllm": + model_type = "mllm" + + if isinstance(quant_config, SignRoundConfig): + return _get_compressor_class(model_type, CalibCompressor)(alg_configs, **local_args, **kwargs) + + elif isinstance(quant_config, RTNConfig): + enable_imatrix = False + disable_opt_rtn = getattr(quant_config, "disable_opt_rtn", False) + # If disable_opt_rtn was not explicitly set and scheme is W8A16/W8A8, + # auto-disable optimization to improve efficiency. 
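+            # (Illustrative: RTNConfig() with scheme="W8A8" is treated as if the user had passed
+            # disable_opt_rtn=True, unless disable_opt_rtn was set explicitly on the config.)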
+ if getattr(quant_config, "orig_disable_opt_rtn", None) is None: + if isinstance(scheme, str) and scheme.upper() in ["W8A16", "W8A8"]: + logger.warning("`disable_opt_rtn` is turned on for W8A16/W8A8 quantization to improve efficiency.") + disable_opt_rtn = True + quant_config.disable_opt_rtn = True + if not disable_opt_rtn: + has_gguf_k = "gguf" in format.lower() and "_k" in format.lower() if format else False + if has_gguf_k: + enable_imatrix = True + else: + # Resolve scheme attrs for routing (config hasn't been through + # SchemeMixin yet; user may have specified only scheme="W4A16"). + _resolved = _preview_resolved_attrs(quant_config, scheme) + _sym = _resolved.get("sym", getattr(quant_config, "sym", None)) + _data_type = _resolved.get("data_type", getattr(quant_config, "data_type", "") or "") + _bits = _resolved.get("bits", getattr(quant_config, "bits", None)) + if _sym is not None and _sym is False: + enable_imatrix = False + elif _data_type == "int" and (_bits is None or _bits < 8): + enable_imatrix = True + elif is_weight_scheme(scheme): + enable_imatrix = True + else: + _resolved = {} + + _resolved = _resolved if not disable_opt_rtn else _preview_resolved_attrs(quant_config, scheme) + _act_bits = _resolved.get("act_bits", getattr(quant_config, "act_bits", None)) + _act_data_type = _resolved.get("act_data_type", getattr(quant_config, "act_data_type", None)) + _act_dynamic = _resolved.get("act_dynamic", getattr(quant_config, "act_dynamic", None)) + _is_act_quantize = _act_bits is not None and _act_bits <= 8 + needs_act_calib = _is_act_quantize and check_need_act_calibration( + _act_dynamic, + _act_data_type, + _act_bits if _act_bits is not None else 16, + static_kv_dtype=kwargs.get("static_kv_dtype"), + static_attention_dtype=kwargs.get("static_attention_dtype"), + ) + + # AutoScheme always requires calibration data for delta-loss based + # scheme selection, regardless of whether imatrix is needed. + from auto_round.auto_scheme.gen_auto_scheme import AutoScheme as _AutoScheme + + is_auto_scheme = isinstance(scheme, _AutoScheme) + + if enable_imatrix or needs_act_calib or is_auto_scheme: + quant_config._alg_cls = "OptimizedRTNQuantizer" + return _get_compressor_class(model_type, CalibratedRTNCompressor)(alg_configs, **local_args, **kwargs) + else: + quant_config._alg_cls = "RTNQuantizer" + return _get_compressor_class(model_type, ZeroShotCompressor)(alg_configs, **local_args, **kwargs) + + +class AutoRoundCompatible: + """AutoRoundCompatible wrapper class for backward compatibility. + + This class provides the same API as the old AutoRoundCompatible class but internally + uses the new AutoRound architecture with Mixin pattern. + + Args: + model: Model object or model name to load + tokenizer: Tokenizer for text processing + platform: Platform to download model ("hf" or "model_scope") + scheme: Quantization scheme (str, dict, or QuantizationScheme) + layer_config: Layer-wise quantization config + dataset: Calibration data + iters: Optimization iterations + seqlen: Calibration sequence length + nsamples: Number of calibration samples + batch_size: Calibration batch size + gradient_accumulate_steps: Gradient accumulation steps + low_gpu_mem_usage: Lower GPU memory mode + device_map: Device map for each module + enable_torch_compile: Enable torch.compile + seed: Random seed + low_cpu_mem_usage: Lower CPU memory mode + **kwargs: Additional arguments (bits, group_size, sym, etc.) 
+ + Example: + >>> # Old API - still works + >>> from auto_round.compressors_new.entry import AutoRoundCompatible + >>> autoround = AutoRoundCompatible( + ... model="/models/opt-125m", + ... bits=4, + ... group_size=128, + ... iters=200, + ... ) + >>> quantized_model, layer_config = autoround.quantize() + """ + + SKIP_ARGS = ("local_args", "kwargs", "cls", "config") + + bits: int | None + group_size: int | None + sym: bool | None + data_type: str | None + act_bits: int | None + act_group_size: int | None + act_sym: bool | None + act_data_type: str | None + act_dynamic: bool | None + super_bits: int | None + super_group_size: int | None + + @staticmethod + def _pop_config_kwargs(kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + """Extract old-API config kwargs and split them by config type.""" + common_keys = ( + "ignore_layers", + "quant_lm_head", + "scale_dtype", + "super_bits", + "super_group_size", + "to_quant_block_names", + ) + auto_round_only_keys = ( + "nblocks", + "enable_alg_ext", + "lr_scheduler", + "not_use_best_mse", + "dynamic_max_gap", + "optimizer", + "enable_adam", + "momentum", + ) + common_kwargs = {} + auto_round_kwargs = {} + for key in common_keys: + if key in kwargs: + common_kwargs[key] = kwargs.pop(key) + for key in auto_round_only_keys: + if key in kwargs: + auto_round_kwargs[key] = kwargs.pop(key) + return common_kwargs, auto_round_kwargs + + def __new__( + cls, + model: Union[torch.nn.Module, str], + tokenizer=None, + platform: str = "hf", + scheme: Union[str, dict, QuantizationScheme, AutoScheme] = "W4A16", + layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, + dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", + iters: int = 200, + seqlen: int = 2048, + nsamples: int = 128, + batch_size: int = 8, + gradient_accumulate_steps: int = 1, + low_gpu_mem_usage: bool = False, + device_map: Union[str, torch.device, int, dict] = 0, + enable_torch_compile: bool = False, + seed: int = 42, + low_cpu_mem_usage: bool = True, + **kwargs, + ): + """Create AutoRoundCompatible instance using new AutoRound architecture. + + This method translates old AutoRoundCompatible API to new AutoRound API. 
+ """ + from auto_round.utils import is_diffusion_model, is_mllm_model + + common_config_kwargs, auto_round_config_kwargs = cls._pop_config_kwargs(kwargs) + + # Extract quantization parameters from kwargs or use defaults + bits = kwargs.pop("bits", None) + group_size = kwargs.pop("group_size", None) + sym = kwargs.pop("sym", None) + data_type = kwargs.pop("data_type", None) + act_bits = kwargs.pop("act_bits", None) + act_group_size = kwargs.pop("act_group_size", None) + act_sym = kwargs.pop("act_sym", None) + act_data_type = kwargs.pop("act_data_type", None) + act_dynamic = kwargs.pop("act_dynamic", None) + + # Decide which algorithm to use + if iters == 0: + # RTN mode + disable_opt_rtn = kwargs.pop("disable_opt_rtn", None) + config = RTNConfig( + bits=bits, + group_size=group_size, + sym=sym, + data_type=data_type, + act_bits=act_bits, + act_group_size=act_group_size, + act_sym=act_sym, + act_data_type=act_data_type, + act_dynamic=act_dynamic, + disable_opt_rtn=disable_opt_rtn, + # for optRTN + batch_size=batch_size, + **common_config_kwargs, + ) + else: + # AutoRoundCompatible mode + lr = kwargs.pop("lr", None) + minmax_lr = kwargs.pop("minmax_lr", None) + enable_minmax_tuning = kwargs.pop("enable_minmax_tuning", True) + enable_norm_bias_tuning = kwargs.pop("enable_norm_bias_tuning", False) + enable_quanted_input = kwargs.pop("enable_quanted_input", True) + + config = SignRoundConfig( + iters=iters, + batch_size=batch_size, + gradient_accumulate_steps=gradient_accumulate_steps, + bits=bits, + group_size=group_size, + sym=sym, + data_type=data_type, + act_bits=act_bits, + act_group_size=act_group_size, + act_sym=act_sym, + act_data_type=act_data_type, + act_dynamic=act_dynamic, + lr=lr, + minmax_lr=minmax_lr, + enable_minmax_tuning=enable_minmax_tuning, + enable_norm_bias_tuning=enable_norm_bias_tuning, + enable_quanted_input=enable_quanted_input, + **common_config_kwargs, + **auto_round_config_kwargs, + ) + + # Determine output format if specified + format = kwargs.pop("format", None) + + # Extract rotation_config (old-API kwarg) and thread it into alg_configs. + # In old arch this was a standalone keyword arg; the new arch passes rotation + # transforms as part of the alg_configs list. All backends (auto / inplace / + # transform) are dispatched inside ``HadamardRotation.apply_to_model``. 
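+        # Illustrative call (hypothetical model path):
+        #     AutoRoundCompatible("/models/opt-125m", rotation_config="default", iters=200)
+        # appends a default RotationConfig to alg_configs next to the quantization config.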
+ _rotation_config_raw = kwargs.pop("rotation_config", None) + if _rotation_config_raw is not None: + if isinstance(_rotation_config_raw, _NewArchRotationConfig): + _rc = _rotation_config_raw + elif isinstance(_rotation_config_raw, dict): + _rc = _NewArchRotationConfig.model_validate(_rotation_config_raw) + else: + # str alias ("default", "random_hadamard", …) -> default config + _rc = _NewArchRotationConfig() + config = [config, _rc] + + # Extract MLLM-specific parameters + processor = kwargs.pop("processor", None) + image_processor = kwargs.pop("image_processor", None) + template = kwargs.pop("template", None) + extra_data_dir = kwargs.pop("extra_data_dir", None) + quant_nontext_module = kwargs.pop("quant_nontext_module", False) + + # Extract Diffusion-specific parameters + guidance_scale = kwargs.pop("guidance_scale", 7.5) + num_inference_steps = kwargs.pop("num_inference_steps", 50) + generator_seed = kwargs.pop("generator_seed", None) + + # Check model type for logging + if is_mllm_model(model, platform=platform): + logger.info("Using MLLM mode for multimodal model (new architecture).") + elif is_diffusion_model(model): + logger.info("Using Diffusion mode for diffusion model (new architecture).") + else: + logger.info("Using LLM mode (new architecture).") + + # Create AutoRound instance using new architecture + compressor = AutoRound( + alg_configs=config, + model=model, + tokenizer=tokenizer, + platform=platform, + format=format, + scheme=scheme, + dataset=dataset, + iters=iters, + low_gpu_mem_usage=low_gpu_mem_usage, + device_map=device_map, + enable_torch_compile=enable_torch_compile, + seed=seed, + low_cpu_mem_usage=low_cpu_mem_usage, + layer_config=layer_config, + nsamples=nsamples, + seqlen=seqlen, + # MLLM parameters + processor=processor, + image_processor=image_processor, + template=template, + extra_data_dir=extra_data_dir, + quant_nontext_module=quant_nontext_module, + # Diffusion parameters + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator_seed=generator_seed, + # Pass remaining kwargs + **kwargs, + ) + + return compressor diff --git a/auto_round/compressors_new/mllm_mixin.py b/auto_round/compressors_new/mllm_mixin.py new file mode 100644 index 000000000..ada1eacf5 --- /dev/null +++ b/auto_round/compressors_new/mllm_mixin.py @@ -0,0 +1,281 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from auto_round.logger import logger +from auto_round.utils import to_device + + +class MLLMMixin: + """MLLM-specific functionality mixin. + + This mixin adds MLLM-specific functionality to any compressor (CalibCompressor, + ZeroShotCompressor, ImatrixCompressor, etc). It handles multi-modal models + (vision-language models) that require special data loading and processing logic. 
+ + Can be combined with: + - CalibCompressor (for AutoRound with calibration) + - ImatrixCompressor (for RTN with importance matrix) + - ZeroShotCompressor (for basic RTN) + + MLLM-specific parameters: + processor: Multi-modal processor override (normally loaded by ModelContext) + image_processor: Image processor override (e.g. for LLaVA) + template: Template name for processing different MLLMs + extra_data_dir: Path to extra data (images, audio, videos) + quant_nontext_module: Whether to quantize non-text modules + + Design note: + ``ModelContext._load_model()`` is responsible for loading the model and its + associated artifacts (processor, tokenizer, image_processor). This mixin + reads those artifacts from ``self.model_context`` during calibration. + If the caller passes explicit ``processor`` / ``image_processor`` overrides, + they are written into ``model_context`` after ``super().__init__()`` so that + ``model_context`` remains the single source of truth. + """ + + def __init__( + self, + *args, + processor=None, + image_processor=None, + template=None, + extra_data_dir=None, + quant_nontext_module=False, + **kwargs, + ): + self.template = template + self.extra_data_dir = extra_data_dir + self.quant_nontext_module = quant_nontext_module + self.template_obj = None + + # Pass quant_nontext_module to ModelContext so get_block_names can include vision blocks + kwargs.setdefault("quant_nontext_module", quant_nontext_module) + + # Mirror old arch: reset batch_size to 1 when quantizing non-text modules, + # because vision encoder blocks have non-standard hidden_states shapes that + # break batch_dim detection, and image collation fails with batch_size > 1. + if quant_nontext_module: + # batch_size may come from kwargs (placed there by AutoRoundCompatible local_args) + # or from the AlgConfig object in args[0] (the authoritative source for quantizer.batch_size). + # We must update both so that quantizer.batch_size is also reset to 1. + batch_size = kwargs.get("batch_size", None) + _alg_cfg = args[0] if args else None + if batch_size is None and _alg_cfg is not None: + cfgs = _alg_cfg if isinstance(_alg_cfg, list) else [_alg_cfg] + for cfg in cfgs: + if hasattr(cfg, "batch_size") and cfg.batch_size is not None: + batch_size = cfg.batch_size + break + if batch_size is not None and batch_size != 1: + grad_acc = kwargs.get("gradient_accumulate_steps", 1) + new_grad_acc = batch_size * grad_acc + kwargs["gradient_accumulate_steps"] = new_grad_acc + kwargs["batch_size"] = 1 + # Also patch the AlgConfig object so that BaseCompressor.quantize_config.batch_size == 1 + if _alg_cfg is not None: + cfgs = _alg_cfg if isinstance(_alg_cfg, list) else [_alg_cfg] + for cfg in cfgs: + if hasattr(cfg, "batch_size"): + cfg.batch_size = 1 + if hasattr(cfg, "gradient_accumulate_steps"): + cfg.gradient_accumulate_steps = new_grad_acc + logger.warning( + f"reset batch_size({batch_size}) to 1 and " + f"gradient_accumulate_steps to {new_grad_acc} " + f"because batch_size={batch_size} cannot be used for calibrating non-text modules." + ) + + # super().__init__() creates model_context, which eagerly loads the model and + # populates model_context.processor / image_processor / tokenizer. + super().__init__(*args, **kwargs) + + # Apply user-provided overrides into model_context (single source of truth). 
+ if processor is not None: + self.model_context.processor = processor + if image_processor is not None: + self.model_context.image_processor = image_processor + + @torch.no_grad() + def calib(self, nsamples, bs): + """Perform MLLM-specific calibration for quantization. + + Override parent's calib method to use MLLM dataset loading logic. + All multimodal artifacts are read from ``self.model_context``. + """ + from transformers import PreTrainedModel + + from auto_round.compressors.mllm.dataset import get_mllm_dataloader + from auto_round.compressors.mllm.template import get_template + from auto_round.special_model_handler import MISTRAL_3_2_MODELS + + mc = self.model_context + processor = mc.processor + image_processor = mc.image_processor + tokenizer = mc.tokenizer + + # Handle template selection + if isinstance(mc.model, PreTrainedModel): + model_type = getattr(mc.model.config, "model_type", None) + if model_type == "llava" and self.template is None: + self.template = "default" + + if hasattr(mc.model, "name_or_path"): + name = mc.model.name_or_path + if any([m in name for m in MISTRAL_3_2_MODELS]): + self.template = "mistral3_2" + + template_name = self.template + if template_name is None and hasattr(mc.model.config, "model_type"): + template_name = mc.model.config.model_type + if template_name is None: + template_name = "default" + + self.template_obj = get_template( + template_name, + model=mc.model, + tokenizer=tokenizer, + processor=processor, + image_processor=image_processor, + use_rtn=getattr(self.quantize_config, "iters", None) == 0, + quiet=not self.quant_nontext_module, + ) + + logger.info(f"Using MLLM template: {template_name}") + + dataset = self.dataset.replace(" ", "") if isinstance(self.dataset, str) else self.dataset + if dataset is None: + dataset = self.template_obj.default_dataset + + if isinstance(self.dataset, str): + dataset = self.dataset.replace(" ", "") + # Mirror old arch __init__: switch text-only dataset to MLLM dataset when + # quant_nontext_module=True, as text datasets cannot calibrate vision modules. 
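+            # (Illustrative: a default text dataset such as "NeelNanda/pile-10k" would be
+            # swapped for the llava calibration set when quant_nontext_module=True.)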
+ from auto_round.calib_dataset import CALIB_DATASETS + + if self.quant_nontext_module and dataset in CALIB_DATASETS: + logger.warning( + "Text only dataset cannot be used for calibrating non-text modules," + " switching to liuhaotian/llava_conv_58k" + ) + dataset = "liuhaotian/llava_conv_58k" + ( + self.dataloader, + self.batch_size, + self.seqlen, + self.gradient_accumulate_steps, + ) = get_mllm_dataloader( + template=self.template_obj, + model=mc.model, + tokenizer=tokenizer, + processor=processor, + image_processor=image_processor, + dataset=dataset, + extra_data_dir=self.extra_data_dir, + seqlen=self.seqlen, + bs=bs, + seed=self.seed, + nsamples=nsamples, + quant_nontext_module=self.quant_nontext_module, + ) + else: + self.dataloader = self.dataset + + # Process data through the model for calibration + total_cnt = 0 + for data in self.dataloader: + if data is None: + continue + + try: + if isinstance(data, str): + # List-of-strings dataset: process through template → model inputs + processed = self.template_obj.processor.get_input( + text=data, images=None, max_length=self.seqlen, squeeze=False + ) + data_new = {k: to_device(v, mc.model.device) for k, v in processed.items()} + elif isinstance(data, dict) and "text" in data: + # FakeDataLoader-style {"text": ..., "image": ...}: process through template + text = data["text"] + if isinstance(text, dict): + text = [text] + input_text = self.template_obj._encode(text) + processed = self.template_obj.processor.get_input( + text=input_text, + images=data.get("image", None), + max_length=self.seqlen, + squeeze=False, + ) + data_new = {} + for key, value in processed.items(): + tensor_val = value if isinstance(value, torch.Tensor) else torch.as_tensor(value) + data_new[key] = to_device(tensor_val, mc.model.device) + elif isinstance(data, dict): + data_new = { + key: value.to(mc.model.device) if isinstance(value, torch.Tensor) else value + for key, value in data.items() + } + else: + data_new = data + + if isinstance(data_new, dict): + mc.model(**data_new) + else: + mc.model(data_new) + except NotImplementedError: + pass + except Exception as e: + logger.warning(f"Calibration forward pass failed: {e}") + continue + + total_cnt += bs + if total_cnt >= nsamples: + break + + if total_cnt == 0: + logger.error("no data has been cached, please provide more data") + exit(-1) + elif total_cnt < nsamples: + logger.warning(f"Insufficient number of samples: required {nsamples}, but only {total_cnt} were processed.") + + def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs): + """Save the quantized model to the specified output directory in the specified format. + + Args: + output_dir (str, optional): The directory to save the quantized model. Defaults to None. + format (str, optional): The format in which to save the model. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place. Defaults to True. + **kwargs: Additional keyword arguments specific to the export format. + + Returns: + object: The compressed model object. 
+ """ + mc = self.model_context + processor = mc.processor + image_processor = mc.image_processor + tokenizer = mc.tokenizer + + if processor is not None and not hasattr(processor, "chat_template"): + processor.chat_template = None + compressed_model = super().save_quantized( + output_dir=output_dir, + format=format, + inplace=inplace, + processor=processor, + image_processor=image_processor, + quant_nontext_module=self.quant_nontext_module if hasattr(self, "quant_nontext_module") else False, + **kwargs, + ) + return compressed_model diff --git a/auto_round/compressors_new/shard_writer.py b/auto_round/compressors_new/shard_writer.py new file mode 100644 index 000000000..dbdd2cc86 --- /dev/null +++ b/auto_round/compressors_new/shard_writer.py @@ -0,0 +1,329 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from collections import OrderedDict + +import torch + +from auto_round.compressors_new.utils import _get_save_folder_name +from auto_round.context.compress import CompressContext +from auto_round.context.model import ModelContext +from auto_round.logger import logger +from auto_round.utils import get_lm_head_name, get_module + + +class ShardWriter: + """ + Handles shard-saving of model parameters to disk with memory management. + """ + + _instance = None + _initialized = False + + model = None + lm_head_name = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._data = {} + return cls._instance + + def __init__( + self, + model, + bits, + max_shard_size=None, + safe_serialization=True, + ): + if ShardWriter._initialized: + return + self.model = model + self.lm_head_name = get_lm_head_name(self.model) + total_params = sum(p.numel() for p in self.model.parameters()) + # Heuristic estimate of model size in GB used to choose a default max_shard_size: + # - total_params * rounder.bits -> total number of bits in all parameters + # - // 8 -> convert bits to bytes + # - // 1e9 -> approx convert bytes to GB (1e9 bytes ~= 1 GB) + # - final // 10 -> apply a safety margin so default shards are + # smaller than the full model; this intentionally + # underestimates size before clamping below. + max_split_num = 10 + model_size = int(total_params * bits // 1e9 // 8 + max_split_num - 1) / max_split_num + model_size = max(1, min(int(model_size), 5)) + + # Configuration + max_shard_size = max_shard_size or f"{model_size}GB" + self.max_shard_size = self._parse_size(max_shard_size) + self.safe_serialization = safe_serialization + + # Internal State + self.use_safetensors = self._check_safetensors() + self.shard_suffix = "safetensors" if self.use_safetensors else "bin" + self.current_shard_tensors = OrderedDict() + self.current_shard_size = 0 + self.shard_meta = [] # List of {tmp_file: str, params: list} + self.global_weight_map = {} + self.shard_counter = 0 + + # Persistent set of all parameter names already flushed to a shard file. 
+ # Maintained incrementally in _flush_shard to avoid O(N^2) rebuilds in _add_tensor. + self._all_saved = set() + + # Stats + self.total_param_elems = 0 + self.total_param_size_bytes = 0 + self.skipped_meta_tensors = [] + + ShardWriter._initialized = True + + @property + def output_dir(self) -> str: + """Derive the output directory from the current CompressContext at access time. + + Reading from context rather than caching the path at construction time ensures + the ShardWriter always uses the final export directory even if + ``CompressContext.output_dir`` is updated after the ShardWriter was created + (e.g. by ``_get_export_dir()`` in ``quantize_and_save()``). + """ + compress_context = CompressContext.get_context() + formats = compress_context.formats + base_dir = _get_save_folder_name(formats[0]) + subfolder = getattr(self.model, "_autoround_pipeline_subfolder", None) + if subfolder: + base_dir = os.path.join(base_dir, subfolder) + return os.path.join(base_dir, "") + + @classmethod + def reset(cls): + """Reset the singleton state so the next instantiation creates a fresh ShardWriter.""" + cls._initialized = False + cls._instance = None + + @classmethod + def get_shard_writer(cls, *args, **kwargs): + """Return the current singleton instance, or None if not yet initialized. + + Callers that require a valid writer should guard the result with + ``if self.compress_context.is_immediate_saving`` before use. + """ + return cls._instance + + def _parse_size(self, size_str: str) -> int: + if isinstance(size_str, int): + return size_str + s = size_str.strip().upper() + units = {"GB": 1024**3, "MB": 1024**2, "KB": 1024, "B": 1} + for unit, mult in units.items(): + if s.endswith(unit): + return int(float(s[: -len(unit)]) * mult) + return int(s) + + def _check_safetensors(self) -> bool: + if self.safe_serialization: + try: + import safetensors.torch + + return True + except ImportError: + logger.warning("safetensors not installed; falling back to torch.save.") + return False + + def save_module(self, m: torch.nn.Module, name: str = None): + """Extracts and accumulates tensors from a module.""" + prefix = name if name is not None else getattr(m, "global_name", "model") + sd = m.state_dict() + + for k, v in sd.items(): + if not isinstance(v, torch.Tensor): + continue + param_name = f"{prefix}.{k}" + self._add_tensor(param_name, v) + + def _add_tensor(self, name: str, tensor: torch.Tensor): + if isinstance(tensor, torch.Tensor) and tensor.device.type == "meta": + self.skipped_meta_tensors.append(name) + return + + # Guard against duplicate saving of the same parameter + if name in self._all_saved or name in self.current_shard_tensors: + return + + t_size = tensor.nbytes + self.total_param_elems += tensor.numel() + self.total_param_size_bytes += t_size + tensor = tensor.detach().cpu() + # If single tensor exceeds limit, flush current, save it solo, then continue + if t_size > self.max_shard_size: + self._flush_shard() + self.current_shard_tensors[name] = tensor + self.current_shard_size = t_size + self._flush_shard() + # If adding exceeds limit, flush first + elif self.current_shard_size + t_size > self.max_shard_size and self.current_shard_size > 0: + self._flush_shard() + self.current_shard_tensors[name] = tensor + self.current_shard_size = t_size + else: + self.current_shard_tensors[name] = tensor + self.current_shard_size += t_size + + def _handle_tied_weights(self): + """ + Detects tied weights in the current shard and ensures they are only saved once. 
+ This is done by tracking storage pointers of tensors and skipping duplicates. + """ + storage_map = set() + filtered_tensors = {} + + for name, tensor in self.current_shard_tensors.items(): + if not isinstance(tensor, torch.Tensor): + filtered_tensors[name] = tensor + continue + + ptr = tensor.untyped_storage().data_ptr() + tensor.storage_offset() * tensor.element_size() + if ptr not in storage_map: + storage_map.add(ptr) + filtered_tensors[name] = tensor + self.current_shard_tensors = filtered_tensors + + def _flush_shard(self): + if not self.current_shard_tensors: + return + + self.shard_counter += 1 + output_dir = self.output_dir + os.makedirs(output_dir, exist_ok=True) + tmp_name = f"model-shard-{self.shard_counter:05d}.{self.shard_suffix}" + tmp_path = os.path.join(output_dir, tmp_name) + self._handle_tied_weights() + + if self.use_safetensors: + from safetensors.torch import save_file + + # Ensure tensors are contiguous in-place to avoid duplicating them in a separate dict, + # which can increase peak RAM usage during saving. + for k, v in list(self.current_shard_tensors.items()): + if isinstance(v, torch.Tensor) and not v.is_contiguous(): + self.current_shard_tensors[k] = v.contiguous() + save_file(self.current_shard_tensors, tmp_path) + else: + torch.save(self.current_shard_tensors, tmp_path) + + saved_params = list(self.current_shard_tensors.keys()) + self.shard_meta.append({"tmp_file": tmp_name, "params": saved_params, "dir": output_dir}) + self._all_saved.update(saved_params) + + # Offload logic: move modules to meta device once all params are saved + self._offload_to_meta(saved_params) + + self.current_shard_tensors = OrderedDict() + self.current_shard_size = 0 + + def _offload_to_meta(self, saved_params): + """Attempts to move fully saved modules to the 'meta' device to free RAM.""" + for param_full_name in saved_params: + module_path = param_full_name.rsplit(".", 1)[0] + + module = get_module(self.model, module_path) + # Check if all parameters of this module are now in '_all_saved' + if ( + module is not None + and isinstance(module, torch.nn.Module) + and all(f"{module_path}.{k}" in self._all_saved for k in module.state_dict().keys()) + ): + module.to("meta") + + def finalize(self): + """Saves remaining weights, renames files, and writes the index JSON.""" + # 1. Capture remaining weights not yet saved + full_sd = self.model.state_dict() + tie_word_embeddings = False + if hasattr(self.model, "config") and hasattr(self.model.config, "tie_word_embeddings"): + tie_word_embeddings = self.model.config.tie_word_embeddings + + finalize_skipped_meta_tensors = [] + for pname, tensor in full_sd.items(): + if pname in self._all_saved: + continue + if tensor.device.type == "meta": + continue + layer_name = ".".join(pname.split(".")[:-1]) + if self.lm_head_name is not None and layer_name == self.lm_head_name and tie_word_embeddings: + lm_head_module = get_module(self.model, self.lm_head_name) + lm_head_module.to("meta") # Must to meta, otherwise model's saver will dump it again + continue + self._add_tensor(pname, tensor.detach().to("cpu")) + + self._flush_shard() + + total_skipped = len(self.skipped_meta_tensors) + len(finalize_skipped_meta_tensors) + if total_skipped > 0: + examples = (self.skipped_meta_tensors + finalize_skipped_meta_tensors)[:5] + + # 2. 
Rename temp files to HF standard and map weights
+        if self.shard_counter == 0:
+            logger.warning("No tensors saved.")
+            return
+
+        output_dir = self.output_dir
+        for idx, meta in enumerate(self.shard_meta, start=1):
+            shard_dir = meta.get("dir", output_dir)
+            old_path = os.path.join(shard_dir, meta["tmp_file"])
+            new_name = (
+                f"model.{self.shard_suffix}"
+                if self.shard_counter == 1
+                else f"model-{idx:05d}-of-{self.shard_counter:05d}.{self.shard_suffix}"
+            )
+            new_path = os.path.join(shard_dir, new_name)
+            os.rename(old_path, new_path)
+            for p in meta["params"]:
+                self.global_weight_map[p] = new_name
+
+        # 3. Write Index JSON
+        index_ext = "safetensors.index.json" if self.use_safetensors else "bin.index.json"
+        index_path = os.path.join(output_dir, f"model.{index_ext}")
+
+        index_data = {
+            "metadata": {
+                "format": "safetensors" if self.use_safetensors else "pytorch",
+                "total_shards": self.shard_counter,
+                "total_parameters": int(self.total_param_elems),
+                "total_size": int(self.total_param_size_bytes),
+            },
+            "weight_map": self.global_weight_map,
+        }
+
+        if self.shard_counter > 1:
+            with open(index_path, "w", encoding="utf-8") as f:
+                json.dump(index_data, f, indent=2)
+
+        logger.info(f"model has been saved to {self.output_dir}")
+
+    @torch.no_grad()
+    def write(self, m: torch.nn.Module = None, name: str = None, is_finalize: bool = False):
+        if m is None and name is None and not is_finalize:
+            raise ValueError("Must specify either name or m")
+        if m is None and name is not None:
+            m = get_module(self.model, name)
+        # Perform the save
+        if m is not None:
+            self.save_module(m, name)
+
+        if is_finalize:
+            self.finalize()
+            ShardWriter._initialized = False
+            ShardWriter._instance = None
diff --git a/auto_round/compressors_new/utils.py b/auto_round/compressors_new/utils.py
new file mode 100644
index 000000000..fcebb45d7
--- /dev/null
+++ b/auto_round/compressors_new/utils.py
@@ -0,0 +1,1246 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy +import json +import os +import random +import re +import sys +from dataclasses import asdict, fields +from enum import Enum +from typing import TYPE_CHECKING, Callable, Union + +import torch +import transformers +from torch.amp import autocast + +from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, GGUF_CONFIG, GGUF_INNER_CONFIG, QK_K, ModelType +from auto_round.logger import logger +from auto_round.utils import check_to_quantized, get_layer_names_in_block, get_module, to_standard_regex + +if TYPE_CHECKING: + from auto_round.schemes import QuantizationScheme + + +class BackendDataType(str, Enum): + STANDARD_FP = "fp" + MX_FP = "mx_fp" + NV_FP = "nv_fp" + MX_INT = "mx_int" + + +def is_standard_fp(backend): + backend = backend.lower() + return BackendDataType.STANDARD_FP in backend and not is_mx_fp(backend) and not is_nv_fp(backend) + + +def is_mx_fp(backend): + backend = backend.lower() + return BackendDataType.MX_FP in backend + + +def is_mx_int(backend): + backend = backend.lower() + return BackendDataType.MX_INT in backend + + +def is_nv_fp(backend): + backend = backend.lower() + return BackendDataType.NV_FP in backend + + +def is_wint_woq(ar): + """Returns True for integer weight-only quantization with non-quantized activations (`act_bits >= 16`).""" + return "int" in ar.data_type and ar.act_bits >= 16 and ar.super_group_size is None + + +def is_wint_a16(ar): + """Backward-compatible alias for `is_wint_woq()`.""" + return is_wint_woq(ar) + + +def _is_weight_fp8_activation_static_fp8( + bit: int, group_size: int, sym: bool, data_type: str, act_dynamic: bool +) -> bool: + return bit == 8 and group_size == -1 and sym and data_type == "fp" and not act_dynamic + + +def is_wfp8afp8(ar): + if ( + ("fp8" in ar.act_data_type or ("fp" in ar.act_data_type and ar.act_bits == 8)) + and ("fp8" in ar.data_type or ("fp" in ar.data_type and ar.bits == 8)) + and is_standard_fp(ar.act_data_type) + and is_standard_fp(ar.data_type) + ): + return True + else: + return False + + +def is_wint8aint8(ar): + if ("int8" in ar.act_data_type or ("int" in ar.act_data_type and ar.act_bits == 8)) and ( + "int8" in ar.data_type or ("int" in ar.data_type and ar.bits == 8) + ): + return True + else: + return False + + +def is_static_wfp8afp8(ar_or_format: Union[str, Callable]) -> bool: + if isinstance(ar_or_format, str): + return "fp8_static" in ar_or_format.lower() + if ar_or_format.act_dynamic: + return False + if is_wfp8afp8(ar_or_format): + return True + return False + + +def is_dynamic_wint8aint8(ar_or_format: Union[str, Callable]) -> bool: + if isinstance(ar_or_format, str): + return "int8_w8a8" in ar_or_format.lower() + if not ar_or_format.act_dynamic: + return False + if is_wint8aint8(ar_or_format): + return True + return False + + +def is_dynamic_afp8(ar_or_format: Callable) -> bool: + return ar_or_format.act_dynamic and ar_or_format.act_data_type.startswith("fp") and ar_or_format.act_bits == 8 + + +def is_block_wfp8(ar_or_format: Callable) -> bool: + return ( + isinstance(ar_or_format.group_size, tuple) + and len(ar_or_format.group_size) == 2 + and ar_or_format.data_type.startswith("fp") + and ar_or_format.bits == 8 + ) + + +def block_forward( + block: torch.nn.Module, + input_ids: torch.Tensor, + input_others: dict, + amp: bool = False, + amp_dtype: torch.dtype = torch.float16, + device: torch.device = torch.device("cpu"), + output_return_id: int = 0, +) -> Union[torch.Tensor, dict]: + """Performs a forward pass through a block with the given inputs. 
+ + Args: + block: The block to perform the forward pass on. + input_ids: The input IDs. + input_others: A dictionary containing other input data. + amp: A boolean indicating whether to use automatic mixed precision. + amp_dtype: The data type for automatic mixed precision. + device: The target device. + output_return_id: if the output has more than one tenor, return the specified idx tensor. + + Returns: + output: The output of the forward pass. + """ + from auto_round.utils.model import to_device + + if input_ids.device != device: + input_ids = to_device(input_ids, device) + input_others = to_device(input_others, device) + input_tuple = input_others.pop("positional_inputs", None) + if "alibi" in input_others.keys() and input_others["alibi"] is not None: + alibi = input_others["alibi"] + input_others["alibi"] = alibi.reshape(-1, alibi.shape[2], alibi.shape[3]) + if amp: + with autocast(device_type=str(device).split(":")[0], dtype=amp_dtype): # pragma: no cover + output = block(input_ids, *input_tuple, **input_others) + else: + output = block(input_ids, *input_tuple, **input_others) + if isinstance(output_return_id, int) and (isinstance(output, list) or isinstance(output, tuple)): + output = output[output_return_id] + return output + + +def check_skippable_keywords(key): + """ + Prints a reminder if a key is not stored during quantization fine-tuning. + """ + skippable_cache_keys = ("past_key_value",) + for cache_key in skippable_cache_keys: + if cache_key not in key: + return True + return False + + +def check_need_act_calibration( + is_act_dynamic: Union[bool, None], + act_data_type: Union[str, None] = None, + act_bits: Union[int, None] = 16, + static_kv_dtype: Union[str, None] = None, + static_attention_dtype: Union[str, None] = None, +) -> bool: + if static_kv_dtype is not None or static_attention_dtype is not None: + return True + if act_bits is None or act_bits > 8: + return False + # None is dynamic + if is_act_dynamic is not None and not is_act_dynamic: + return True + if act_data_type is not None and "static" in act_data_type: + return True + return False + + +def collect_best_params(block, cache_device="cpu"): + """Collect the best parameters from the block to the specified device.""" + params = {} + if hasattr(block, "orig_layer"): + for key in block.params.keys(): + params[key] = block.params[key].data.to(cache_device, copy=True) + else: + for n, m in block.named_modules(): + if hasattr(m, "orig_layer"): + params[n] = {} + for key in m.params.keys(): + params[n][key] = m.params[key].data.to(cache_device, copy=True) + return params + + +def infer_bits_by_data_type(data_type: str): + """Infer bits by data_type + + Args: + data_type (str): data_type + + Returns: + int: bits inferred by data_type, None means cannot infer correct bits by data_type + """ + from auto_round.utils import SUPPORTED_DTYPES + + if data_type is None: + return 16 + for supported_dtype in SUPPORTED_DTYPES: + if data_type.startswith(supported_dtype) and len(data_type) > len(supported_dtype): + ##first check the following two bits + suc_2str = data_type[len(supported_dtype) : len(supported_dtype) + 2] + if str.isdigit(suc_2str): + return int(suc_2str) + if str.isdigit(data_type[len(supported_dtype)]): + return int(data_type[len(supported_dtype)]) + return None + + +def _get_safetensor_layer_names_not_in_model(model, all_module_names: list) -> list: + """Collect layer names from safetensor files that are not loaded into the model. + + Some tensors (e.g. 
MTP layers) exist in the original checkpoint but are not + instantiated by ``transformers``. This function discovers them so that regex + patterns in ``layer_config`` can still match them. + + Returns: + List of layer names (the path without the ``.weight`` suffix) for weight + tensors present in the safetensor files but absent from *all_module_names*. + """ + name_or_path = None + if hasattr(model, "config") and hasattr(model.config, "name_or_path"): + name_or_path = model.config.name_or_path + if not name_or_path: + return [] + + if not os.path.isdir(name_or_path): + try: + from auto_round.utils.model import download_hf_model + + name_or_path = download_hf_model(name_or_path) + except Exception as e: + logger.debug(f"Could not resolve source model path to check for missing tensors: {e}") + return [] + + try: + from safetensors import safe_open + except ImportError: + return [] + + # Build tensor-name list from the safetensors index or single file + source_index_file = os.path.join(name_or_path, "model.safetensors.index.json") + source_single_file = os.path.join(name_or_path, "model.safetensors") + + tensor_names: list = [] + if os.path.exists(source_index_file): + with open(source_index_file) as f: + src_index = json.load(f) + tensor_names = list(src_index["weight_map"].keys()) + elif os.path.exists(source_single_file): + with safe_open(source_single_file, framework="pt", device="cpu") as f: + tensor_names = list(f.keys()) + else: + return [] + + module_name_set = set(all_module_names) + extra_layer_names = [] + for tensor_name in tensor_names: + if not tensor_name.endswith(".weight"): + continue + layer_name = tensor_name[: -len(".weight")] + if layer_name not in module_name_set: + extra_layer_names.append(layer_name) + return extra_layer_names + + +def set_layer_config( + model: torch.nn.Module, + layer_config: dict[str, Union[str, dict, "QuantizationScheme"]], + default_scheme: Union[str, "QuantizationScheme"], + default_scale_dtype: torch.dtype | str, + supported_types: tuple, + inner_supported_types: tuple, + quant_block_list=None, + ignore_layers: str = "", + quant_lm_head: bool = False, + enable_gguf_official_mixed: bool = True, + is_mllm: bool = False, + fill_default_value=True, + gguf_format_name: str = None, +) -> tuple[dict, bool, dict]: + """ + Normalize, validate, and expand layer-specific quantization configs. + Returns (final_layer_config, has_quant_layer_outside_block) + """ + + from auto_round.schemes import QuantizationScheme, get_gguf_scheme, preset_name_to_scheme + from auto_round.utils.model import get_layer_names_in_block, get_lm_head_name, get_module, is_separate_lm_head + + # ---- helpers ------------------------------------------------- + def dispatch_layer_config(layer_config: dict[str, dict]) -> None: + """Assign scheme values as attributes to matched modules.""" + for layer_name, scheme in layer_config.items(): + module = get_module(model, layer_name) + if module is None: + # Layer exists in safetensor files but is not loaded into the model + # (e.g. MTP layers that transformers does not instantiate). Skip. 
+ continue + for attr, value in scheme.items(): + setattr(module, attr, value) + + def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str) -> dict: + """Convert config entry into dict and validate keys.""" + if isinstance(item, str): + config = asdict(preset_name_to_scheme(item.upper())) + elif isinstance(item, QuantizationScheme): + config = asdict(item) + elif isinstance(item, dict): + # "in_blocks" is an internal bookkeeping key injected by LLM-Compressor; + # silently drop it before validation. + item = {k: v for k, v in item.items() if k != "in_blocks"} + invalid = set(item) - set(scheme_keys + ("fixed_by_user", "scale_dtype")) + if invalid: + raise ValueError( + f"Invalid keys {invalid} in layer_config for '{layer_name}'. " f"Allowed keys: {scheme_keys}" + ) + config = dict(item) + else: + raise TypeError( + f"Unsupported type for layer_config[{layer_name}]: {type(item)}. " + f"Expected str, dict, or QuantizationScheme." + ) + # Clean up + config = {k: v for k, v in config.items() if v is not None} + config["fixed_by_user"] = True + return config + + # ---- main logic ---------------------------------------------- + extra_scheme_keys = ("scale_dtype",) + scheme_keys = tuple(f.name for f in fields(QuantizationScheme)) + ("scale_dtype",) + layer_config = copy.deepcopy(layer_config) or {} + ignore_layer_patterns = set() + if ignore_layers: + ignore_layers = ignore_layers.replace(" ", "").split(",") + ignore_layers = [name + "." if name[-1].isdigit() else name for name in ignore_layers] + ignore_layer_patterns = set(ignore_layers) + + # 1. ignore_layers -> force 16 + for name in get_fp_layer_names(model, ignore_layers): + layer_config[name] = { + "bits": 16, + "act_bits": 16, + "data_type": "float", + "act_data_type": "float", + "fixed_by_user": True, + } + + # 2. normalize + layer_config = {k: normalize_item(v, k) for k, v in layer_config.items()} + + # 3. infer missing bits + for cfg in layer_config.values(): + if "data_type" in cfg and "bits" not in cfg: + if (b := infer_bits_by_data_type(cfg["data_type"])) is not None: + cfg["bits"] = b + if "act_data_type" in cfg and "act_bits" not in cfg: + if (b := infer_bits_by_data_type(cfg["act_data_type"])) is not None: + cfg["act_bits"] = b + + # 4. fill defaults + if isinstance(default_scheme, str): + default_dict = asdict(preset_name_to_scheme(default_scheme.upper())) + else: + default_dict = asdict(default_scheme) + default_dict["scale_dtype"] = default_scale_dtype + + # In AutoScheme with mixed gguf:q4_k_m, the super_group_size of gguf:q8_0 layer is None, + # which should not be filled by default q4km again + for cfg in layer_config.values(): + for key in scheme_keys: + if fill_default_value: + cfg.setdefault(key, copy.deepcopy(default_dict.get(key))) + else: + if key in extra_scheme_keys: + cfg.setdefault(key, copy.deepcopy(default_dict.get(key))) + else: + cfg.setdefault(key, None) + + # 5. 
collect supported modules + embedding_types = (torch.nn.Embedding,) + gguf_name = gguf_format_name or get_gguf_scheme(default_scheme) + if gguf_name: + if torch.nn.Embedding not in supported_types: + supported_types = (*supported_types, torch.nn.Embedding) + + # for some Embedding which type() is not torch.nn.Embedding + # for example: transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding + model_module_name = model.__class__.__module__ + module_cls = sys.modules[model_module_name] + for name in module_cls.__dict__: + if name.endswith("Embedding") and not name.endswith("RotaryEmbedding"): + embedding_types = (*embedding_types, getattr(module_cls, name)) + supported_types = (*supported_types, *embedding_types) + + all_supported_layer_names, embedding_layer_names = [], [] + all_module_names = [] + for n, m in model.named_modules(): + all_module_names.append(n) + # cleanup stale attributes + for key in scheme_keys: + if hasattr(m, key): + delattr(m, key) + if type(m) not in supported_types and m.__class__.__name__ not in inner_supported_types: + continue + all_supported_layer_names.append(n) + if isinstance(m, embedding_types) or m.__class__.__name__.endswith("Embedding"): + embedding_layer_names.append(n) + + # Also include layer names from safetensor files not loaded into the model + # (e.g. MTP layers that transformers does not instantiate). + safetensor_only_names = _get_safetensor_layer_names_not_in_model(model, all_module_names) + + # 6. expand regex configs + regex_config = {} + for name in list(layer_config.keys()): + if name in all_supported_layer_names: + continue + if name in all_module_names: + m = get_module(model, name) + if len(list(m.children())) == 0 and type(m) not in supported_types: + val = layer_config.pop(name) + if name in ignore_layer_patterns: + # Keep unsupported ignore_layers entries so export can serialize + # them into regex-based extra_config for loaders like vLLM INC. + regex_config[name] = val + else: + logger.warning( + f"'{name}' exists in the model but is not a supported quantization target " + f"in the current scheme, ignoring its setting in `layer_config`" + ) + continue + + regex = re.compile(to_standard_regex(name)) + matched = [ln for ln in all_supported_layer_names if regex.search(ln)] + safetensor_only_matched = [ln for ln in safetensor_only_names if regex.search(ln)] + # skip it for mtp layers not loaded in transformers + if not matched and not safetensor_only_matched: + # type(mlp.gate) is Qwen3VLMoeTextTopKRouter instead of Linear + logger.warning_once( + f"Layer name or regex '{name}' in layer_config does not match any supported layers. " + + "Please check for typos or update the regex pattern, ignore it for now" + ) + val = layer_config.pop(name) + regex_config[name] = val # keep regex config + for match in matched: + layer_config[match] = val + + # 7. 
lm_head + lm_head_name = get_lm_head_name(model) + tie_word_embeddings = False + if hasattr(model, "config") and hasattr(model.config, "tie_word_embeddings"): + tie_word_embeddings = model.config.tie_word_embeddings + + if lm_head_name in layer_config: + quant_lm_head = True + + if quant_lm_head and tie_word_embeddings and not gguf_name: + quant_lm_head = False + logger.warning( + "reset `quant_lm_head` to false as quantizing " "lm_head with tied weights has not been supported currently" + ) + + if lm_head_name not in layer_config and quant_lm_head: + layer_config[lm_head_name] = copy.deepcopy(default_dict) + + if not quant_lm_head and not gguf_name: + layer_config.pop(lm_head_name, None) + + # 8. enforce shape divisibility for int weight-only + if default_dict["data_type"] == "int" and default_dict["act_bits"] >= 16 and not gguf_name: + for n, m in model.named_modules(): + if type(m) in supported_types or m.__class__.__name__ in inner_supported_types: + if m.weight.shape[0] % 32 or m.weight.shape[1] % 32: + layer_config.setdefault(n, copy.deepcopy(default_dict)) + layer_config[n].update({"bits": 16, "data_type": "fp", "fixed_by_user": True}) + # logger.warning_once(f"{n} skipped quantization (shape not divisible by 32).") + # enforce shape divisibility for mxfp/nvfp + if (is_nv_fp(default_dict["data_type"]) or is_mx_fp(default_dict["data_type"])) and not gguf_name: + for n, m in model.named_modules(): + if type(m) in supported_types or m.__class__.__name__ in inner_supported_types: + if m.weight.shape[1] % default_dict["group_size"]: + layer_config.setdefault(n, copy.deepcopy(default_dict)) + layer_config[n].update( + {"bits": 16, "data_type": "fp", "act_bits": 16, "act_data_type": "fp", "fixed_by_user": True} + ) + logger.warning_once( + f"{n} skipped quantization (shape not divisible by {default_dict['group_size']})." + ) + + # 9. block layers: mark as in_blocks=True + for name in get_layer_names_in_block(model, supported_types, quant_block_list, inner_supported_types): + if name not in layer_config: + layer_config[name] = copy.deepcopy(default_dict) + layer_config[name]["fixed_by_user"] = False + layer_config[name]["in_blocks"] = True + + # ---- restore: ensure missing in_blocks are set to False and compute flag ---- + has_qlayer_outside_block = False + for cfg in layer_config.values(): + if "in_blocks" not in cfg: + cfg["in_blocks"] = False + # mark layer outside block + if not cfg["in_blocks"] and check_to_quantized(cfg): + has_qlayer_outside_block = True + + # 10. 
GGUF handling + if not gguf_name: + dispatch_layer_config(layer_config) + return layer_config, has_qlayer_outside_block, regex_config + + # embed + lm_head defaults for gguf + tie_word_embeddings &= not is_separate_lm_head(model) + if lm_head_name not in layer_config and not tie_word_embeddings: + cfg = GGUF_INNER_CONFIG[GGUF_CONFIG[gguf_name.lower()]["lm_head"]] + cfg = {**cfg, "fixed_by_user": False, "scale_dtype": default_scale_dtype} + layer_config[lm_head_name] = cfg + has_qlayer_outside_block = True + for emd_name in embedding_layer_names: + if emd_name in layer_config: + continue + if not tie_word_embeddings: + cfg = GGUF_INNER_CONFIG[GGUF_CONFIG[gguf_name.lower()]["embedding"]] + else: + cfg = GGUF_INNER_CONFIG[GGUF_CONFIG[gguf_name.lower()]["lm_head"]] + cfg = {**cfg, "fixed_by_user": False, "scale_dtype": default_scale_dtype} + layer_config[emd_name] = cfg + + if enable_gguf_official_mixed: + model_type = ModelType.MMPROJ if is_mllm else ModelType.TEXT + layer_config, _ = get_layer_config_by_gguf_format(layer_config, gguf_name.lower(), model, model_type) + + dispatch_layer_config(layer_config) + return layer_config, has_qlayer_outside_block, regex_config + + +def _use_more_bits(i_layer: int, n_layer: int): + return (i_layer < n_layer // 8) or (i_layer >= 7 * n_layer // 8) or ((i_layer - n_layer // 8) % 3 == 2) + + +def _search_gguf_type(gguf_type): + if gguf_type in GGUF_INNER_CONFIG: + return gguf_type + pattern = re.compile("gguf:q([0-9]{1,})_[01k]") + bits = re.search(pattern, gguf_type) + if not bits: + raise KeyError(f"{gguf_type} is not a correct gguf type, please check") + + for suffix in ["_k", "_0", "_1"]: + if gguf_type.endswith(suffix): + continue + if (tmp_type := re.sub("_[01k]", suffix, gguf_type)) in GGUF_INNER_CONFIG: + return tmp_type + return None + + +def gguf_type_fallback(gguf_type: str) -> str: + gguf_type = gguf_type.lower() + if gguf_type in ("gguf:q2_k", "gguf:q3_k", "gguf:q4_k"): + gguf_type = "gguf:q5_0" + elif gguf_type == "gguf:q5_k": + gguf_type = "gguf:q5_0" + elif gguf_type == "gguf:q6_k": + gguf_type = "gguf:q8_0" + return gguf_type + + +def get_gguf_qtype_by_layer_config(layer_config): + import gguf # pylint: disable=E0401 + + if layer_config["bits"] >= 16: + return None + bits = layer_config["bits"] + super_bits = layer_config.get("super_bits", None) + sym = layer_config["sym"] + group_size = layer_config.get("group_size", None) + super_group_size = layer_config.get("super_group_size", None) + if bits == 2 and super_bits == 4 and not sym and group_size == 16 and super_group_size == 16: + return gguf.GGMLQuantizationType.Q2_K + if bits == 3 and super_bits == 6 and sym and group_size == 16 and super_group_size == 16: + return gguf.GGMLQuantizationType.Q3_K + if bits == 4: + if super_bits is not None and super_bits == 6 and not sym and group_size == 32 and super_group_size == 8: + return gguf.GGMLQuantizationType.Q4_K + if super_bits is None and sym and group_size == 32: + return gguf.GGMLQuantizationType.Q4_0 + if super_bits is None and not sym and group_size == 32: + return gguf.GGMLQuantizationType.Q4_1 + if bits == 5: + if super_bits == 6 and not sym and group_size == 32 and super_group_size == 8: + return gguf.GGMLQuantizationType.Q5_K + if super_bits is None and sym and group_size == 32: + return gguf.GGMLQuantizationType.Q5_0 + if super_bits is None and not sym and group_size == 32: + return gguf.GGMLQuantizationType.Q5_1 + if bits == 6 and super_bits == 8 and group_size == 16 and super_group_size == 16: + return 
gguf.GGMLQuantizationType.Q6_K + if bits == 8 and sym and group_size == 32: + return gguf.GGMLQuantizationType.Q8_0 + raise ValueError("Unknown layer config") + + +def _get_digital_in_layer_name(layer_name): + pattern = re.compile(r"([a-zA-Z]+\.){1,}(\d+)") + res = re.search(pattern, layer_name) + if res: + return int(res[2]) + else: + return None + + +def _gguf_type_fallback(gguf_type: str) -> str: + gguf_type = gguf_type.lower() + if gguf_type in ("gguf:q2_k", "gguf:q3_k", "gguf:q4_k"): + gguf_type = "gguf:q5_0" + elif gguf_type == "gguf:q5_k": + gguf_type = "gguf:q5_0" + elif gguf_type == "gguf:q6_k": + gguf_type = "gguf:q8_0" + return gguf_type + + +##https://github.com/ggml-org/llama.cpp/blob/9e31bec4fd53634c9e5b04650488a09a055f5dab/src/llama-quant.cpp#L129 +def get_layer_config_by_gguf_format(layer_config, target_gguf_format: str, model, model_type=ModelType.TEXT): + # # TODO: support for other format later + # target_gguf_format = next((fmt for fmt in gguf_format if fmt != "fake"), None) + + import gguf # pylint: disable=E0401 + + from auto_round.schemes import QuantizationScheme, get_gguf_scheme + from auto_round.utils.common import MM_KEYS, LazyImport + from auto_round.utils.model import get_lm_head_name, get_module + + # from auto_round.export.export_to_gguf.convert import ModelBase, get_model_architecture + convert_hf_to_gguf = LazyImport("auto_round.export.export_to_gguf.convert_hf_to_gguf") + + try: + model_architecture = convert_hf_to_gguf.get_model_architecture( + hparams=model.config.to_dict(), model_type=model_type + ) + except AttributeError as e: + raise ImportError( + "Please use the latest gguf-py, you can use the following command to install it:\n" + "git clone https://github.com/ggml-org/llama.cpp.git && cd llama.cpp/gguf-py" + " && pip install . 
sentencepiece" + ) + try: + if model_type != ModelType.TEXT: + model_class_vision = convert_hf_to_gguf.ModelBase.from_model_architecture( + model_architecture, model_type=model_type + ) + model_class = convert_hf_to_gguf.ModelBase.from_model_architecture( + model_architecture, model_type=ModelType.TEXT + ) + + except NotImplementedError: + return layer_config, {} + + n_layer = None + if model_type != ModelType.TEXT: + n_layer_vision = None + for name in ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]: + if hasattr(model.config, name): + n_layer = getattr(model.config, name) + if model_type != ModelType.TEXT: + if n_layer is not None and hasattr(model.config, "text_config"): + if hasattr(getattr(model.config, "text_config"), name): + n_layer = getattr(getattr(model.config, "text_config"), name) + for config_name in ["vision_config", "vision_encoder"]: + if hasattr(model.config, config_name): + if hasattr(getattr(model.config, config_name), name): + n_layer_vision = getattr(getattr(model.config, config_name), name) + break + if n_layer and n_layer_vision: + break + + if n_layer is None: + return layer_config, {} + + tensor_map = gguf.get_tensor_name_map(model_class.model_arch, n_layer) + if model_type != ModelType.TEXT: + tensor_map_vision = gguf.get_tensor_name_map(model_class_vision.model_arch, n_layer_vision) + + def _set_config(config, target_config): + for k, v in target_config.items(): + if isinstance(config, dict): + config[k] = v + else: + setattr(config, k, v) + return config + + gguf_format_config = {} + lm_head_name = get_lm_head_name(model) + inner_gguf_format = GGUF_CONFIG[target_gguf_format]["mostly"] + # ggml_type = getattr(gguf.GGMLQuantizationType,inner_gguf_format.split(":")[-1].upper()) + block_size = GGML_QUANT_SIZES[inner_gguf_format.split(":")[-1].lower()][0] + tie_word_embeddings = True + if hasattr(model, "config") and hasattr(model.config, "tie_word_embeddings"): + tie_word_embeddings = model.config.tie_word_embeddings + + n_gqa = 1 + if ( + hasattr(model, "config") + and hasattr(model.config, "num_attention_heads") + and hasattr(model.config, "num_key_value_heads") + ): + n_gqa = model.config.num_attention_heads // model.config.num_key_value_heads + n_expert = 0 + for name in ["num_experts", "num_local_experts", "n_routed_experts"]: + if hasattr(model.config, name): + n_expert = getattr(model.config, name) + + i_attention_wv = 0 + i_ffn_down = 0 + layer_config_copy = copy.deepcopy(layer_config) + base_target_bits = None + if inner_gguf_format.startswith("gguf:q") and len(inner_gguf_format) >= 7 and (inner_gguf_format[6]).isdigit(): + base_target_bits = int(inner_gguf_format[6]) + + for layer_name, config in layer_config_copy.items(): + if not check_to_quantized(config): + continue + # Reset target_bits each iteration to prevent lm_head/embedding settings + # from bleeding into subsequent block layers and bypassing their special logic. 
+            target_bits = base_target_bits
+            new_type = GGUF_CONFIG[target_gguf_format]["mostly"]
+            layer = get_module(model, layer_name)
+            if type(layer) == transformers.pytorch_utils.Conv1D:
+                input_features = layer.weight.shape[0]
+            else:
+                input_features = layer.weight.shape[-1]
+            i_layer = _get_digital_in_layer_name(layer_name)
+
+            if lm_head_name is not None and layer_name == lm_head_name:
+                target_bits = int(
+                    re.search("gguf:q([0-9]{1,})_[01k]", GGUF_CONFIG[target_gguf_format]["lm_head"]).group(1)
+                )
+            if isinstance(layer, torch.nn.Embedding):
+                target_bits = int(
+                    re.search("gguf:q([0-9]{1,})_[01k]", GGUF_CONFIG[target_gguf_format]["embedding"]).group(1)
+                )
+
+            if model_type != ModelType.TEXT and any([key in layer_name for key in MM_KEYS]):
+                gguf_name = tensor_map_vision.get_name(layer_name)
+                if gguf_name is None:
+                    for key in MM_KEYS:
+                        gguf_name = tensor_map_vision.get_name(layer_name.replace(f".{key}", ""))
+                        if gguf_name is not None:
+                            break
+            else:
+                gguf_name = tensor_map.get_name(layer_name)
+                if gguf_name is None:
+                    gguf_name = tensor_map.get_name(layer_name.replace(".language_model", ""))
+            bits_index = 6
+            if config.get("fixed_by_user", False):
+                if "bits" not in config:
+                    logger.warning(
+                        f"Setting layer_config requires providing bits, {layer_name} has no bits,"
+                        f" using bits={target_bits} instead."
+                    )
+                    new_type = new_type[:bits_index] + str(target_bits) + new_type[bits_index + 1 :]
+                else:
+                    config_tmp = config.copy()
+                    scheme_keys = [f.name for f in fields(QuantizationScheme)]
+                    for key in config.keys():
+                        if key not in scheme_keys:
+                            config_tmp.pop(key, None)
+                    matched_scheme = get_gguf_scheme(QuantizationScheme.from_dict(config_tmp))  # check matched
+                    if not matched_scheme:
+                        if config.get("super_group_size", None) is not None or config.get("super_bits", None) is not None:
+                            new_type = new_type[:bits_index] + str(config["bits"]) + "_k"
+                            if new_type not in GGUF_INNER_CONFIG:
+                                prefix_idx = 0 if config.get("sym", True) else 1
+                                new_type = new_type[:bits_index] + str(config["bits"]) + f"_{prefix_idx}"
+                                if new_type not in GGUF_INNER_CONFIG:
+                                    new_type = new_type[:bits_index] + str(config["bits"]) + f"_{1-prefix_idx}"
+                                if new_type not in GGUF_INNER_CONFIG:
+                                    raise ValueError(
+                                        f"the setting in layer_config {layer_name} "
+                                        f"could not match any supported gguf format, please have a check."
+ ) + + new_type = new_type[:bits_index] + str(config["bits"]) + new_type[bits_index + 1 :] + new_type = _search_gguf_type(new_type) + if new_type is None: + raise ValueError(f"invalid bit setting for {layer_name}") + elif target_bits is not None and "bits" in config and config["bits"] != target_bits: + new_type = new_type[:bits_index] + str(config["bits"]) + new_type[bits_index + 1 :] + new_type = _search_gguf_type(new_type) + if new_type is None: + raise ValueError(f"invalid bit setting for {layer_name}") + elif lm_head_name is not None and layer_name == lm_head_name and not tie_word_embeddings: + if gguf.MODEL_ARCH.FALCON == model_class.model_arch or input_features % block_size != 0: + new_type = "gguf:q8_0" + elif "lm_head" in GGUF_CONFIG[target_gguf_format]: + new_type = GGUF_CONFIG[target_gguf_format]["lm_head"] + elif new_type != "gguf:q8_0": + new_type = "gguf:q6_k" + elif lm_head_name is not None and layer_name == lm_head_name and tie_word_embeddings: + # new_type = GGUF_CONFIG[target_gguf_format]["lm_head"] + continue + elif isinstance(layer, torch.nn.Embedding): + if "embedding" in GGUF_CONFIG[target_gguf_format]: + new_type = GGUF_CONFIG[target_gguf_format]["embedding"] + elif gguf_name is None: + pass + # attn_v + elif "attn_v" in gguf_name: + if target_gguf_format == "gguf:q2_k": + new_type = "gguf:q4_k" if n_gqa >= 4 else "gguf:q3_k" + elif target_gguf_format == "gguf:q2_k_s" and n_gqa >= 4: + new_type = "gguf:q4_k" + elif target_gguf_format == "gguf:q3_k_m": + new_type = "gguf:q5_k" if i_attention_wv < 2 else "gguf:q4_k" + elif target_gguf_format == "gguf:q3_k_l": + new_type = "gguf:q5_k" + elif (target_gguf_format == "gguf:q4_k_m" or target_gguf_format == "gguf:q5_k_m") and _use_more_bits( + i_layer, n_layer + ): + new_type = "gguf:q6_k" + elif target_gguf_format == "gguf:q4_k_s" and i_attention_wv < 4: + new_type = "gguf:q5_k" + ##TODO check which models are be grouped into to LLM_TYPE_70B + # if (qs.model.type == LLM_TYPE_70B) { + # // In the 70B model we have 8 heads sharing the same attn_v weights. 
+ # As a result, the attn_v.weight tensor is + # // 8x smaller compared to attn_q.weight.Hence, we can get a nice boost in quantization accuracy with + # // nearly negligible increase in model size by quantizing this tensor with more bits: + # if + # (new_type == GGML_TYPE_Q3_K | | new_type == GGML_TYPE_Q4_K) + # new_type = GGML_TYPE_Q5_K; + # } + if n_expert == 8: + new_type = "gguf:q8_k" + i_attention_wv += 1 + + elif "attn_k" in gguf_name: + if n_expert == 8: + new_type = "gguf:q8_0" + # ffn_down + elif "ffn_down" in gguf_name: + if target_gguf_format == "gguf:q2_k": + new_type = "gguf:q3_k" + elif target_gguf_format == "gguf:q2_k_s": + if i_layer < n_layer / 8: + new_type = "gguf:q4_k" + elif target_gguf_format == "gguf:q3_k_m": + if i_layer < n_layer / 16: + new_type = "gguf:q5_k" + elif gguf.MODEL_ARCH.FALCON == model_class.model_arch or _use_more_bits(i_layer, n_layer): + new_type = "gguf:q4_k" + else: + new_type = "gguf:q3_k" + elif target_gguf_format == "gguf:q3_k_l": + if gguf.MODEL_ARCH.FALCON == model_class.model_arch: + new_type = "gguf:q4_k" + else: + new_type = "gguf:q5_k" + elif target_gguf_format == "gguf:q4_k_m": + if gguf.MODEL_ARCH.FALCON == model_class.model_arch: + if i_layer < n_layer // 16: + new_type = "gguf:q6_k" + elif _use_more_bits(i_layer, n_layer): + new_type = "gguf:q5_k" + else: + new_type = "gguf:q4_k" + else: + if _use_more_bits(i_layer, n_layer): + new_type = "gguf:q6_k" + elif target_gguf_format == "gguf:q5_k_m" and _use_more_bits(i_layer, n_layer): + new_type = "gguf:q6_k" + elif ( + target_gguf_format == "gguf:q4_k_s" + and model_class.model_arch != gguf.MODEL_ARCH.FALCON + and i_layer < n_layer / 8 + ): + new_type = "gguf:q5_k" + elif (target_gguf_format == "gguf:q4_0" or target_gguf_format == "gguf:q5_0") and i_layer < n_layer / 8: + if target_gguf_format == "gguf:q4_0": + new_type = "gguf:q4_1" + else: + new_type = "gguf:q5_1" + i_ffn_down += 1 + + # attn_output + elif "attn_output" in gguf_name: + if gguf.MODEL_ARCH.FALCON != model_class.model_arch: + if n_expert == 8: + if target_gguf_format in ( + "gguf:q2_k", + "gguf:q3_k_s", + "gguf:q3_k_m", + "gguf:q4_k_s", + "gguf:q4_k_m", + "gguf:q5_k", + ): + new_type = "gguf:q5_k" + elif target_gguf_format == "gguf:q2_k": + new_type = "gguf:q3_k" + elif target_gguf_format == "gguf:q3_k_m": + new_type = "gguf:q4_k" + elif target_gguf_format == "gguf:q3_k_l": + new_type = "gguf:q5_k" + else: + if target_gguf_format == "gguf:q3_k_l": + new_type = "gguf:q4_k" + # attn_qkv + elif "attn_qkv" in gguf_name: + if target_gguf_format in ("gguf:q3_k_m", "gguf:q3_k_l"): + new_type = "gguf:q4_k" + elif target_gguf_format == "gguf:q4_k_m": + new_type = "gguf:q5_k" + elif target_gguf_format == "gguf:q5_k_m": + new_type = "gguf:q5_k" + new_block_size = GGML_QUANT_SIZES[new_type.split(":")[-1].lower()][0] + if input_features % new_block_size != 0: + new_type = _gguf_type_fallback(new_type) + new_block_size = GGML_QUANT_SIZES[new_type.split(":")[-1].lower()][0] + if input_features % new_block_size != 0: + new_type = "gguf:bf16" + logger.warning( + f"fallback {layer_name} to {new_type}, " + f"because input_features({input_features}) % block_size({block_size}) != 0" + ) + # for deepseek v2 + if layer_name.endswith("kv_b_proj") and new_type.endswith("_k") and "Deepseek" in model.config.architectures[0]: + fallback = False + + # calc if need fallback + qk_nope_head_dim = model.config.qk_nope_head_dim + kv_b_shape = get_module(model, layer_name).weight.shape + + if ( + qk_nope_head_dim < QK_K + or qk_nope_head_dim % QK_K != 0 + 
or kv_b_shape[-1] < QK_K + or kv_b_shape[-1] % QK_K != 0 + ): + fallback = True + if fallback: + tmp_type = _gguf_type_fallback(new_type) + logger.warning_once( + f"self_attn.kv_b_proj does not support the use of {new_type}, replace it with {tmp_type}" + ) + new_type = tmp_type + + target_config = GGUF_INNER_CONFIG[new_type] + + _set_config(layer_config[layer_name], target_config) + _set_config(layer, target_config) + gguf_format_config[layer_name] = new_type + + return layer_config, gguf_format_config + + +def get_fp_layer_names(model: torch.nn.Module, ignore_layers: str): + """Identifies and returns layers in the model to exclude from quantization. + + This function processes a comma-separated list of fully precision (FP) layers, + matches them to the names of layers in the model, and returns a list of such + layers to exclude from quantization. + + Args: + model (torch.nn.Module): The model whose layers will be inspected. + ignore_layers (str): A comma-separated string of layer names to be excluded + from quantization. Whitespace is ignored in this string. + + Returns: + list: A list of layer names that match the specified FP layers or are + subcomponents of those layers. + """ + from auto_round.utils import SUPPORTED_LAYER_TYPES + + if not ignore_layers: + return [] + + all_layer_names = [] + for n, m in model.named_modules(): + if type(m) in SUPPORTED_LAYER_TYPES: + all_layer_names.append(n) + not_to_quantized_layers = [] + + for fp_layer in ignore_layers: + if fp_layer == "": + continue + if fp_layer in all_layer_names: + not_to_quantized_layers.append(fp_layer) + continue + for name in all_layer_names: + if fp_layer in name: + not_to_quantized_layers.append(name) + not_to_quantized_layers.extend(ignore_layers) # keep regex name for later use + logger.trace(f"not_to_quantized_layers: {not_to_quantized_layers}") + return not_to_quantized_layers + + +def get_shared_keys(model): + """ + Retrieves shared keys from the model's state dictionary. + + Args: + model (torch.nn.Module): The model to retrieve shared keys from. + + Returns: + tuple: tuple of shared keys. + """ + from auto_round.special_model_handler import SPECIAL_SHARED_CACHE_KEYS + from auto_round.utils import SHARED_CACHE_KEYS + + shared_keys = SHARED_CACHE_KEYS + shared_keys += SPECIAL_SHARED_CACHE_KEYS.get(model.__class__.__name__, ()) + return shared_keys + + +def init_cache(positional_inputs, inputs): + """ + Initializes special model inputs by adding positional inputs if missing. + + Args: + positional_inputs (list): List of positional inputs to add to inputs. + inputs (dict): Dictionary of model inputs. + + Modifies: + inputs (dict): Adds "positional_inputs" key if not present. + """ + from auto_round.utils.model import to_device + + if "positional_inputs" not in inputs: # for chatglm Series + inputs["positional_inputs"] = [] + for idx, item in enumerate(positional_inputs): + inputs["positional_inputs"] = to_device(positional_inputs) + + +def reset_params(inputs): + """ + Resets specific input parameters to avoid saving the key-value cache during fine-tuning. + + Args: + inputs (dict): Dictionary of model inputs. + + Modifies: + inputs (dict): Sets "use_cache" to False if the key is present. + """ + if "use_cache" in inputs.keys(): # Not storing kv cache + inputs["use_cache"] = False + + +class IndexSampler: + """A cyclic sampler that returns shuffled index batches. + + This sampler maintains internal state so that each call to `next_batch()` + continues from where it left off. 
When the remaining number of samples is + less than `batch_size`, the sampler reshuffles all indices and starts from + the beginning, discarding the last incomplete batch. + + Attributes: + nsamples (int): Total number of samples. + batch_size (int): Number of indices to return in each batch. + index (int): Current position in the index list. + indices (List[int]): Shuffled list of indices. + """ + + def __init__(self, nsamples: int, batch_size: int) -> None: + """Initializes the sampler. + + Args: + nsamples (int): Total number of samples (must be >= batch_size). + batch_size (int): Number of indices per batch. + + Raises: + ValueError: If batch_size is not in the range (0, nsamples]. + """ + if batch_size <= 0 or batch_size > nsamples: + raise ValueError("batch_size must be > 0 and <= nsamples") + + self.nsamples: int = nsamples + self.batch_size: int = batch_size + self.index: int = 0 + + self.indices: list[int] = list(range(nsamples)) + random.shuffle(self.indices) + + def next_batch(self) -> list[int]: + """Returns the next batch of shuffled indices. + + If the remaining indices are fewer than `batch_size`, the sampler + reshuffles the entire list and starts from the beginning. + + Returns: + list[int]: A list of size `batch_size` containing sample indices. + """ + if self.index + self.batch_size > self.nsamples: + random.shuffle(self.indices) + self.index = 0 + + batch = self.indices[self.index : self.index + self.batch_size] + self.index += self.batch_size + return batch + + +def _get_quantized_layer_names_outside_blocks(model, layer_config, supported_types, quant_block_list) -> list: + """Gets the names of quantized layers outside blocks in the model. + + Returns: + list: List of layer names outside blocks. + """ + if layer_config is None or len(layer_config) == 0: + return [] + + layer_names = [] + all_layers_in_block = get_layer_names_in_block(model, supported_types, quant_block_list) + + for key in layer_config.keys(): + if key in all_layers_in_block: + continue + layer = get_module(model, key) + if layer is None: + logger.error(f"could not find layer {key} in the model, exit...") + exit(-1) + if type(layer) in supported_types and check_to_quantized(layer_config[key]): + layer_names.append(key) + + return layer_names + + +def _get_diffusion_save_folder_name(format) -> str: + """Generates the save folder name based on the provided format string. + + If there are multiple formats to handle, the function creates a subfolder + named after the format string with special characters replaced. If there's + only one format, it returns the original output directory directly. + + Args: + format_str (str): The format identifier (e.g., 'gguf:q2_k_s'). + + Returns: + str: The path to the folder where results should be saved. 
+ """ + from auto_round.context.compress import CompressContext + from auto_round.context.model import ModelContext + + compress_context = CompressContext.get_context() + model_context = ModelContext.get_context() + + # Replace special characters to make the folder name filesystem-safe + sanitized_format = format.get_backend_name().replace(":", "-").replace("_", "-") + + formats = compress_context.formats + # Use a subfolder only if there are multiple formats + if len(formats) > 1: + return ( + os.path.join(compress_context.output_dir, sanitized_format, "transformer") + if compress_context.is_immediate_saving + else os.path.join(compress_context.output_dir, sanitized_format, "transformer") + ) + + # if use is_immediate_saving, we need to save model in self.output_dir/transformer folder + return ( + os.path.join(compress_context.output_dir, "transformer") + if compress_context.is_immediate_saving + else compress_context.output_dir + ) + + +def _get_save_folder_name(format, *args, **kwargs) -> str: + """Generates the save folder name based on the provided format string. + + If there are multiple formats to handle, the function creates a subfolder + named after the format string with special characters replaced. If there's + only one format, it returns the original output directory directly. + + Args: + format_str (str): The format identifier (e.g., 'gguf:q2_k_s'). + + Returns: + str: The path to the folder where results should be saved. + """ + from auto_round.context.compress import CompressContext + from auto_round.context.model import ModelContext + + compress_context = CompressContext.get_context() + model_context = ModelContext.get_context() + if model_context.is_diffusion: + return _get_diffusion_save_folder_name(format) + # Replace special characters to make the folder name filesystem-safe + sanitized_format = format.get_backend_name().replace(":", "-").replace("_", "-") + + # Use a subfolder only if there are multiple formats + if len(compress_context.formats) > 1: + return os.path.join(compress_context.output_dir, sanitized_format) + + return compress_context.output_dir + + +def immediate_pack(name: str, layer_config: dict): + from auto_round.context.compress import CompressContext + from auto_round.context.model import ModelContext + + compress_context = CompressContext.get_context() + model_context = ModelContext.get_context() + + if not compress_context.is_immediate_packing: + return + compress_context.formats[0].immediate_pack( + name=name, + model=model_context.model, + device=compress_context.device, + output_dir=_get_save_folder_name(compress_context.formats[0]), + layer_config=layer_config, + tokenizer=model_context.tokenizer, + mllm=model_context.is_mllm, + processor=getattr(model_context, "processor", None), + image_processor=getattr(model_context, "image_processor", None), + quant_nontext_module=getattr(model_context, "quant_nontext_module", False), + ) diff --git a/auto_round/compressors_new/zero_shot.py b/auto_round/compressors_new/zero_shot.py new file mode 100644 index 000000000..977244639 --- /dev/null +++ b/auto_round/compressors_new/zero_shot.py @@ -0,0 +1,257 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from typing import Any, Union + +import torch +from tqdm import tqdm + +from auto_round.algorithms.alg_config import AlgConfig +from auto_round.compressors_new.base import BaseCompressor +from auto_round.compressors_new.utils import is_nv_fp, is_static_wfp8afp8 +from auto_round.logger import logger +from auto_round.modeling.fused_moe.replace_modules import materialize_model_ +from auto_round.utils import ( + check_to_quantized, + clear_memory, + convert_module_to_hp_if_necessary, + flatten_list, + get_block_names, + get_lm_head_name, + get_module, + global_state, + memory_monitor, + mv_module_from_gpu, + set_amax_for_all_moe_layers, + set_module, +) + + +class ZeroShotCompressor(BaseCompressor): + need_calib: bool = False + + def __init__( + self, + config: Union[AlgConfig, list[AlgConfig]], + model: Union[torch.nn.Module, str], + tokenizer=None, + platform="hf", + format=None, + low_gpu_mem_usage: bool = False, + device_map: Union[str, torch.device, int, dict] = 0, + enable_torch_compile: bool = False, + enable_alg_ext: bool = False, + seed: int = 42, + low_cpu_mem_usage: bool = True, + **kwargs, + ): + super().__init__( + config=config, + model=model, + tokenizer=tokenizer, + platform=platform, + format=format, + device_map=device_map, + low_gpu_mem_usage=low_gpu_mem_usage, + enable_torch_compile=enable_torch_compile, + enable_alg_ext=enable_alg_ext, + seed=seed, + low_cpu_mem_usage=low_cpu_mem_usage, + **kwargs, + ) + self.lr = 5e-3 + + def quantize_block( + self, + block: torch.nn.Module, + inputs: tuple, + q_input: Union[torch.Tensor, dict, None] = None, + device: Union[str, torch.device] = "cpu", + auto_offload: bool = True, + ): + """Quantize a single block via RTN (public API for LLM-Compressor). + + ZeroShotCompressor does not need calibration data, so ``inputs`` and + ``q_input`` are accepted for interface compatibility but not used for + algorithm purposes. The block is materialized, converted to the target + dtype, moved to ``device``, and quantized in-place via RTN. + + Returns: + tuple: ``(None, None)`` — RTN does not produce reference outputs. + """ + assert not self.mllm and not self.diffusion, ( + f"Currently, {self.__class__.__name__} does not support quantize_block " "for MLLM / diffusion models." + ) + + if not self._post_init_done: + self.post_init() + + materialize_model_(block) + convert_module_to_hp_if_necessary(block, self.model_context.amp_dtype, device) + block = block.to(device) + + self.quantizer.quantize_block(block) + + # ── MoE scale alignment for FP8 dispatch efficiency ──────────────── + if is_nv_fp(self.quantizer.act_data_type) or is_static_wfp8afp8(self.quantizer): + set_amax_for_all_moe_layers(block, attr_name="act_max") + + mv_module_from_gpu(block) + return None, None + + # Use no_grad instead of inference_mode + # https://github.com/intel/auto-round/issues/1620 + @torch.no_grad() + def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: + """Quantize the model and return the quantized model along with layer configurations.The entry of AutoRound. + Returns: + The quantized model and layer configurations. 
+ """ + + self.post_init() + + formats = self.formats if isinstance(self.formats, list) else [] + if not (any(fmt.is_gguf() for fmt in formats) or self.super_bits is not None): + self._quantize_embedding_layer() # leave to gguf itself to handle + + # Release memory + clear_memory(device_list=self.device_list) + + # By default, we go with layer-wise way if no replacement happened. + # In RTN mode (iters == 0), force blockwise quantization to avoid + # full-model materialization and linear CPU RAM growth. + use_blockwise_quantization = global_state.replaced_module_count > 0 + if not use_blockwise_quantization: + logger.info( + "RTN mode detected (iters=0): force blockwise quantization to avoid " + "layer-wise full-model materialization." + ) + use_blockwise_quantization = True + tied_weights_keys = getattr(self.model, "_tied_weights_keys", []) + if tied_weights_keys is None: + tied_weights_keys = [] + if isinstance(tied_weights_keys, dict): + tied_weights_values = list(tied_weights_keys.values()) + else: + tied_weights_values = list(tied_weights_keys) + tied_weights_layers = [".".join(val.split(".")[:-1]) for val in tied_weights_values] # rm weight/bias + # In fact, we should detect whether it is is_separate_lm_head, to simplify, we don't do it + if getattr(self, "formats", None) and self.formats[0].is_gguf(): + lm_head_name = get_lm_head_name(self.model) + if lm_head_name is not None: + tied_weights_layers.append(lm_head_name) + + if use_blockwise_quantization: # The ram usage is a little higher + + all_blocks = self.quant_block_list or get_block_names(self.model) + pbar = tqdm(range(sum(len(block) for block in all_blocks))) + for block_names in all_blocks: + for block_name in block_names: + pbar.set_description(f"Quantizing {block_name}") + block = get_module(self.model, block_name) + + # ── Infrastructure: materialize ─────────────────────────── + materialize_model_(block) + + # ── Pure algorithm ──────────────────────────────────────── + self.quantizer.quantize_block(block) + + # ── MoE scale alignment for FP8 dispatch efficiency ──────────────── + if is_nv_fp(self.quantizer.act_data_type) or is_static_wfp8afp8(self.quantizer): + set_amax_for_all_moe_layers(block, attr_name="act_max") + + # ── Infrastructure: shard write / device cleanup ────────── + if self.compress_context.is_immediate_saving: + # Save non-quantized leaf modules (e.g. norms, embeddings in block). + for _n, m in block.named_modules(): + if ( + not any(m.children()) + and len(m.state_dict()) > 0 + and hasattr(m, "global_name") + and m.global_name not in tied_weights_layers + and not check_to_quantized(m) + ): + set_module(self.model, m.global_name, copy.deepcopy(m)) + self.shard_writer.write(name=m.global_name) + get_module(self.model, m.global_name).to("meta") + m.to("meta") + # Write at block scope for any remaining params/buffers. 
+ self.shard_writer.write(name=block_name) + block.to("meta") + else: + mv_module_from_gpu(block) + if self.low_cpu_mem_usage: + self._offloader(self.model, block_name) + + clear_memory(device_list=self.device_list) + memory_monitor.log_summary() + pbar.update(1) + cnt = 1 + remain_layer_names = [] + block_name_set = set(name for block in all_blocks for name in block) + for n, m in self.model_context.model.named_modules(): + if not check_to_quantized(m): + continue + # Skip if this layer is part of any block (by prefix match) + if any(n == block_name or n.startswith(f"{block_name}.") for block_name in block_name_set): + continue + remain_layer_names.append(n) + for name in remain_layer_names: + logger.info(f"Quantizing remaining layer {name} on CPU.") + self.quantizer.quantize_layer(name) + cnt += 1 + if cnt % 10 == 0: + clear_memory(device_list=self.device_list) + memory_monitor.log_summary() + else: + all_to_quantized_module_names: list[str] = [ + n for n, m in self.model.named_modules() if check_to_quantized(m) + ] + all_to_quantized_module_names = all_to_quantized_module_names + materialize_model_(self.model) + self.model.to("cpu") + block_names_cnt = len(flatten_list(get_block_names(self.model, True))) + clear_mem_freq = len(all_to_quantized_module_names) // block_names_cnt + cnt = 0 + pbar = tqdm(all_to_quantized_module_names) + + for n, m in self.model.named_modules(): + if hasattr(m, "global_name") and m.global_name in all_to_quantized_module_names: + pbar.set_description(f"Quantizing {m.global_name}") + self.quantizer.quantize_layer(m.global_name) + cnt += 1 + pbar.update() + if cnt % clear_mem_freq == 0: + clear_memory(device_list=self.device_list) + memory_monitor.log_summary() + + elif ( + not any(m.children()) + and len(m.state_dict()) > 0 + and n not in tied_weights_layers + and self.compress_context.is_immediate_saving + ): + set_module(self.model, n, copy.deepcopy(m)) + self.shard_writer.write(name=n) + m.to("meta") + + # Convert remaining fp8 + convert_module_to_hp_if_necessary(self.model, self.amp_dtype, self.device) + if self.low_cpu_mem_usage: + self._offloader.reload(self.model) + if self.compress_context.is_immediate_saving: + self.shard_writer.write(is_finalize=True) + + self.model_context.quantized = True + return self.model, self.layer_config diff --git a/auto_round/context/__init__.py b/auto_round/context/__init__.py new file mode 100644 index 000000000..14a492441 --- /dev/null +++ b/auto_round/context/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/auto_round/context/base.py b/auto_round/context/base.py new file mode 100644 index 000000000..e3f75fb8f --- /dev/null +++ b/auto_round/context/base.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_round.logger import logger + + +class AutoSkipInitMeta(type): + + def __new__(mcs, name, bases, namespace): + if "__init__" in namespace: + original_init = namespace["__init__"] + + def wrapped_init(self, *args, **kwargs): + if getattr(self, "_singleton_skip_init", False): + return + original_init(self, *args, **kwargs) + self._singleton_skip_init = True + + namespace["__init__"] = wrapped_init + + namespace["_instances"] = {} + return super().__new__(mcs, name, bases, namespace) + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + instance = cls.__new__(cls, *args, **kwargs) + cls._instances[cls] = instance + instance.__init__(*args, **kwargs) + + return cls._instances[cls] + + +class BaseContext(metaclass=AutoSkipInitMeta): + _instances = {} + + def __init__(self): + logger.info(f"{self.__class__.__name__} context initialized.") + + @classmethod + def get_context(cls): + assert cls in cls._instances, f"{cls.__name__} context has not been created yet." + return cls._instances.get(cls) + + @classmethod + def create_context(cls, *args, **kwargs): + return cls(*args, **kwargs) + + @classmethod + def reset_context(cls): + cls._instances.pop(cls, None) diff --git a/auto_round/context/compress.py b/auto_round/context/compress.py new file mode 100644 index 000000000..5b92a8b7c --- /dev/null +++ b/auto_round/context/compress.py @@ -0,0 +1,72 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
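To make the `AutoSkipInitMeta` / `BaseContext` singleton behaviour from `context/base.py` above concrete, here is a small usage sketch. The `DummyContext` subclass and its `value` attribute are invented for illustration; only `BaseContext` and its class methods come from this diff.

```python
from auto_round.context.base import BaseContext


class DummyContext(BaseContext):
    def __init__(self, value: int = 0):
        super().__init__()
        self.value = value


first = DummyContext.create_context(value=1)   # creates the singleton; __init__ runs once
second = DummyContext(value=2)                 # returns the same instance; __init__ is skipped
assert first is second and first.value == 1
assert DummyContext.get_context() is first     # get_context() asserts the singleton exists
DummyContext.reset_context()                   # allows a fresh instance to be created later
```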
+from typing import Any, Callable, Optional, Union + +import torch + +from auto_round.context.base import BaseContext +from auto_round.utils.device import ( + clear_memory, + clear_memory_if_reached_threshold, + get_major_device, + parse_available_devices, + set_auto_device_map_for_block_with_tuning, + set_non_auto_device_map, +) + +__all__ = ["CompressContext"] + + +class CompressContext(BaseContext): + + def __init__( + self, + low_cpu_mem_usage: bool = True, + low_gpu_mem_usage: bool = False, + device_map: Union[str, torch.device, int, dict] = 0, + enable_torch_compile: bool = False, + is_immediate_packing: bool = False, + is_immediate_saving: bool = False, + formats: Union[list, str] = None, + output_dir: str = "./compressed_models", + static_kv_dtype: Optional[torch.dtype] = None, + static_attention_dtype: Optional[torch.dtype] = None, + **kwargs, + ): + super().__init__() + self.low_cpu_mem_usage = low_cpu_mem_usage + self.low_gpu_mem_usage = low_gpu_mem_usage + self.formats = formats + self.output_dir = output_dir + if device_map is None: + device_map = 0 + self.device_map = device_map + if isinstance(self.device_map, str): + self.device_map = self.device_map.replace(" ", "") + self.device_list = parse_available_devices(self.device_map) + self.device = get_major_device(self.device_map) + + self.cache_device = torch.device("cpu") if low_gpu_mem_usage else self.device + + self.enable_torch_compile = enable_torch_compile + self.immediate_packing = is_immediate_packing + self.is_immediate_packing = is_immediate_packing + self.is_immediate_saving = is_immediate_saving + self.static_kv_dtype = static_kv_dtype + self.static_attention_dtype = static_attention_dtype + + def clear_memory(self, tensor=None): + """Clear GPU/CPU memory only when ``low_gpu_mem_usage`` is enabled.""" + if self.low_gpu_mem_usage: + clear_memory(tensor, device_list=self.device_list) diff --git a/auto_round/context/model.py b/auto_round/context/model.py new file mode 100644 index 000000000..49648997e --- /dev/null +++ b/auto_round/context/model.py @@ -0,0 +1,292 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
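A minimal sketch of how the `CompressContext` defined in `context/compress.py` above might be used. All argument values are illustrative, and the behaviour of the device helpers (`parse_available_devices`, `get_major_device`) is assumed; `device_map="cpu"` is chosen only so the snippet can run without accelerators.

```python
from auto_round.context.compress import CompressContext

ctx = CompressContext.create_context(
    low_gpu_mem_usage=True,            # routes cached tensors to CPU via ctx.cache_device
    device_map="cpu",                  # assumed to parse into ctx.device_list / ctx.device
    formats=["auto_round"],            # illustrative value
    output_dir="./compressed_models",
)
assert CompressContext.get_context() is ctx
print(ctx.device, ctx.device_list, ctx.cache_device)
ctx.clear_memory()                     # frees memory here because low_gpu_mem_usage is enabled
CompressContext.reset_context()
```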
+ +import gc +import importlib +from typing import Any, Callable, Optional, Union + +import torch +from packaging import version +from transformers import AutoConfig + +from auto_round import envs +from auto_round.compressors.utils import get_shared_keys +from auto_round.context.base import BaseContext +from auto_round.logger import logger +from auto_round.modeling.unfused_moe import apply_model_monkey_patches +from auto_round.special_model_handler import _handle_special_model, update_module +from auto_round.utils import ( + CpuInfo, + check_and_mark_quantized_module, + diffusion_load_model, + is_diffusion_model, + is_mllm_model, + is_moe_model, + is_moe_model_via_config, + llm_load_model, + mllm_load_model, + unsupported_meta_device, +) +from auto_round.utils.device import _force_trim_malloc + +__all__ = ["ModelContext"] + +_CUSTOM_MOE_REPLACEMENT_MODULES = { + "gpt_oss": "auto_round.modeling.fused_moe.gpt_oss", +} + + +class ModelContext(BaseContext): + _is_initialized = False + + # model_related + _model_loaded = False + _init_model = False + hook_handles = [] + + def __init__( + self, + model=None, + tokenizer=None, + platform="hf", + model_dtype=None, + trust_remote_code=True, + amp=True, + need_calib=True, + device="cpu", + formats=None, + is_act_quantize=False, + quant_nontext_module=False, + ): + super().__init__() + self.quantized = False + self.is_mllm = False + self.is_diffusion = False + self.is_model_patched = False + self.is_moe_model = False + + assert model is not None, "model must be provided for ModelContext" + self.model = model + self.tokenizer = tokenizer + self.device = device + + # MLLM / diffusion artifacts – always present so callers need no getattr guards. + # _load_model() will populate the ones that are relevant to the model type. + self.processor = None + self.image_processor = None + self.pipe = None + + if envs.AR_USE_MODELSCOPE: + platform = "model_scope" + self.platform = platform + self.model_dtype = model_dtype + self.trust_remote_code = trust_remote_code + self.amp = amp + self.need_calib = need_calib + self.quant_nontext_module = quant_nontext_module + + # Load model and run basic initialization eagerly so the model is ready + # by the time BaseCompressor.post_init() runs. + self._load_model() + + if unsupported_meta_device(self.model): + raise RuntimeError( + "AutoRound does not support parameters on meta device. " + "Please use more GPUs by setting `--device 0,1,2,3` or just place the model on CPU." + ) + check_and_mark_quantized_module(self.model) + self.model = self.model.eval() + self.shared_cache_keys = get_shared_keys(self.model) + + self.is_moe_model = is_moe_model(self.model) + self._import_custom_moe_replacements(getattr(self.model, "config", None)) + + self._set_amp_dtype() + if is_act_quantize and self.amp_dtype == torch.float16: + logger.warning("force to use bf16 for quantization tuning when enabling activation quantization") + self.amp_dtype = torch.bfloat16 + if self.model.dtype != torch.bfloat16: + self.model = self.model.to(torch.bfloat16) + else: + logger.info(f"using {self.model.dtype} for quantization tuning") + + # Reclaim C heap fragmentation left by model/tokenizer loading so + # that the quantize loop starts from a tighter RSS baseline. 
+ gc.collect() + _force_trim_malloc() + + def _load_model(self): + if is_mllm_model(self.model, platform=self.platform): + self.is_mllm = True + if isinstance(self.model, str): + self.model, self.processor, self.tokenizer, self.image_processor = mllm_load_model( + self.model, platform=self.platform, device="cpu", model_dtype=self.model_dtype + ) + elif is_diffusion_model(self.model): + self.is_diffusion = True + self.pipe, self.model = diffusion_load_model( + self.model, platform=self.platform, device="cpu", model_dtype=self.model_dtype + ) + elif isinstance(self.model, str): + config: Optional[AutoConfig] = None + try: + config = AutoConfig.from_pretrained(self.model, trust_remote_code=self.trust_remote_code) + self._import_custom_moe_replacements(config) + except (OSError, EnvironmentError) as e: + logger.debug( + "Failed to load config via AutoConfig.from_pretrained for %s: %s. " + "Proceeding without config-based checks.", + self.model, + e, + ) + + self.is_model_patched = apply_model_monkey_patches( + model_name=self.model, trust_remote_code=self.trust_remote_code + ) + import transformers + + if ( + not self.is_model_patched + and config is not None + and is_moe_model_via_config(config) + and version.parse(transformers.__version__) >= version.parse("5.0.0") + ): + from auto_round.modeling.fused_moe.replace_modules import BUILTIN_MODULES + + model_type = getattr(config, "model_type", None) + if model_type is not None and model_type not in BUILTIN_MODULES: + logger.warning( + "This MoE model has not been optimized by AutoRound yet, which may result in high RAM usage, " + "Please consider submitting an issue to https://github.com/intel/auto-round/issues" + ) + + # Reclaim temporary HTTP/config objects from model type detection + # and AutoConfig loading before the large model allocation. This + # reduces heap fragmentation especially on HPU where habana internal + # allocations amplify fragmentation into persistent RSS growth. 
+ gc.collect() + _force_trim_malloc() + + self.model, self.tokenizer = llm_load_model( + self.model, + platform=self.platform, + device="cpu", # always load cpu first + model_dtype=self.model_dtype, + trust_remote_code=self.trust_remote_code, + ) + elif self.tokenizer is None and not self.is_diffusion and self.need_calib: + raise ValueError("A tokenizer must be set for non-str model input") + + self._model_loaded = True + + def _import_custom_moe_replacements(self, model_or_config) -> None: + model_type = getattr(model_or_config, "model_type", None) + module_name = _CUSTOM_MOE_REPLACEMENT_MODULES.get(model_type) + if module_name is None: + return + + module = importlib.import_module(module_name) + from auto_round.modeling.fused_moe.replace_modules import BUILTIN_MODULES + + BUILTIN_MODULES.setdefault(model_type, module) + logger.debug(f"Loaded custom MoE replacement module for {model_type}") + + def _patch_custom_moe_modules(self) -> None: + model_type = getattr(getattr(self.model, "config", None), "model_type", None) + if model_type != "qwen3_vl_moe": + return + + for module in self.model.modules(): + if module.__class__.__name__ != "Qwen3VLMoeTextSparseMoeBlock": + continue + if hasattr(module, "top_k"): + continue + + gate = getattr(module, "gate", None) + top_k = getattr(gate, "top_k", None) + if top_k is not None: + setattr(module, "top_k", top_k) + + def _set_amp_dtype(self) -> None: + """Sets the automatic mixed precision (AMP) data type for the model based on the device and configuration.""" + self.amp_dtype = torch.bfloat16 + if self.model.dtype != torch.float32: + self.amp_dtype = self.model.dtype + if self.device == "cpu" or "hpu" in self.device: + self.amp_dtype = torch.bfloat16 + if self.amp: + if self.device == "cpu" and not CpuInfo().bf16: + self.amp = False + self.amp_dtype = torch.float32 + self.model = self.model.to(torch.float32) + logger.warning( + f"amp is set to FALSE as the current {self.device} device does not support the 'bf16' data type." + ) + else: + if self.model.dtype != self.amp_dtype: + self.model = self.model.to(self.amp_dtype) + else: + self.amp_dtype = torch.float32 + self.model = self.model.to(torch.float32) + + def apply_patches(self, formats): + """Apply format-specific model structure patches. + + Must be called after formats are resolved (list[OutputFormat]) and before + BaseQuantizers.post_init() so that configure_layer_config() operates on the + final model structure (post update_module). Eliminates the need for a + subsequent refresh_quantizer_for_initialized_model() call. + """ + # It is best to modify the model structure in the quantize function and check the format, + # because it may cause the gguf format to not be exported normally. + self._patch_custom_moe_modules() + self.model = update_module( + self.model, formats=formats, trust_remote_code=self.trust_remote_code, cleanup_original=False + ) + self.model = _handle_special_model(self.model) + + # Temporary names must be assigned after handle_moe_model; + # placing them earlier would cause them to be removed when the module is replaced. + for n, m in self.model.named_modules(): + m.global_name = n + + if self.amp and self.model.dtype != self.amp_dtype: + self.model = self.model.to(self.amp_dtype) + + self._init_model = True + self._is_initialized = True + + def replace_forward(self, register_hook): + """Replaces the forward function. 
+    register_hook(layer_name, module, hook_handles)
+        """
+        assert self._init_model, "should load and initialize model first"
+        hook_handles = []
+
+        for n, m in self.model.named_modules():
+            register_hook(n, m, hook_handles)
+
+        self.hook_handles = hook_handles
+
+    def recover_forward(self):
+        """Recovers the forward function."""
+        assert self._init_model, "should load and initialize model first"
+
+        for n, m in self.model.named_modules():
+            if hasattr(m, "orig_forward"):
+                m.forward = m.orig_forward
+                delattr(m, "orig_forward")
+        for hook_handle in self.hook_handles:
+            hook_handle.remove()
+        self.hook_handles = []
diff --git a/auto_round/envs.py b/auto_round/envs.py
index a1a0e4f0d..e216a27e2 100644
--- a/auto_round/envs.py
+++ b/auto_round/envs.py
@@ -34,6 +34,7 @@
     "AR_DISABLE_DATASET_SUBPROCESS": lambda: os.getenv("AR_DISABLE_DATASET_SUBPROCESS", "0").lower() in ("1", "true"),
     "AR_DISABLE_COPY_MTP_WEIGHTS": lambda: os.getenv("AR_DISABLE_COPY_MTP_WEIGHTS", "0").lower()
     in ("1", "true", "yes"),
+    "AR_DISABLE_NEW_ARCH": lambda: os.getenv("AR_DISABLE_NEW_ARCH", "0").lower() in ("1", "true", "yes"),
     "AR_ACT_SCALE": lambda: float(os.getenv("AR_ACT_SCALE", "1.0")),
     "AR_ENABLE_ACT_MINMAX_TUNING": lambda: os.getenv("AR_ENABLE_ACT_MINMAX_TUNING", "0").lower()
     in ("1", "true", "yes"),
@@ -79,8 +80,8 @@ def set_config(**kwargs):
     for key, value in kwargs.items():
         if key in environment_variables:
             # Convert value to appropriate string format
-            if key == "AR_USE_MODELSCOPE":
-                # Handle boolean values for AR_USE_MODELSCOPE
+            if key in ("AR_USE_MODELSCOPE", "AR_DISABLE_NEW_ARCH"):
+                # Handle boolean env flags
                 str_value = "true" if value in [True, "True", "true", "1", 1] else "false"
             else:
                 # For other variables, convert to string
diff --git a/auto_round/eval/evaluation.py b/auto_round/eval/evaluation.py
index b3136c32f..e0af76396 100644
--- a/auto_round/eval/evaluation.py
+++ b/auto_round/eval/evaluation.py
@@ -388,8 +388,11 @@ def run_model_evaluation(model, tokenizer, autoround, folders, formats, device_s
         evaluate_diffusion_model(args, autoround=autoround, model=model)
         return

-    # Check if evaluation is needed for language models
-    eval_folder = folders[-1] if folders else None
+    # Check if evaluation is needed for language models
+    if isinstance(folders, list):
+        eval_folder = folders[-1] if folders else None
+    else:
+        eval_folder = folders

     if args.tasks is None or args.tasks == "" or eval_folder is None:
         return
diff --git a/auto_round/experimental/apply_rotation_transform.py b/auto_round/experimental/apply_rotation_transform.py
index fc7dbf297..520f6b2fb 100644
--- a/auto_round/experimental/apply_rotation_transform.py
+++ b/auto_round/experimental/apply_rotation_transform.py
@@ -1,154 +1,12 @@
-# Copyright (c) 2026 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
+# # Copyright (C) 2026 Intel Corporation
+# # SPDX-License-Identifier: Apache-2.0
+"""Backward-compat re-export shim.

-"""Unified entry point for Hadamard rotation/transform.
-
-Two backend implementations exist:
-
-* ``inplace`` – :mod:`auto_round.experimental.rotation_inplace`
-  QuaRot-style residual-stream rotation.  Works for any weight/activation
-  dtype.  Optionally fuses the online Hadamard into weights
-  (``fuse_online_to_weight=True``).
-* ``transform`` – :mod:`auto_round.experimental.transform` - Per-Linear weight + activation Hadamard with a fused triton kernel. - Only supports MXFP4 / NVFP4 and **cannot** fuse online to weight. - -Routing is controlled by :class:`RotationConfig.backend`: - - "inplace" -> always inplace - "transform" -> always transform (validates dtype + no-fuse) - "auto" -> if user asked to fuse -> inplace - elif data_type is mx_fp / nv_fp -> transform - else -> inplace +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.dispatcher`. """ -from __future__ import annotations - -from typing import Any, Union - -import torch - -import auto_round.envs as envs -from auto_round.compressors.utils import is_mx_fp, is_nv_fp -from auto_round.experimental.transform.rotation_config import RotationConfig -from auto_round.experimental.utils import normalize_rotation_config -from auto_round.utils import logger - -__all__ = ["apply_hadamard_rotation", "resolve_hadamard_backend"] - - -def _to_config( - rotation_config: Union[str, dict, RotationConfig, None], - data_type: str, -) -> RotationConfig: - """Normalise *rotation_config* and return a :class:`RotationConfig` instance.""" - cfg_dict = normalize_rotation_config(rotation_config, data_type) - if isinstance(cfg_dict, RotationConfig): - return cfg_dict - return RotationConfig.model_validate(cfg_dict or {}) - - -def resolve_hadamard_backend(config: RotationConfig, data_type: str) -> str: - """Resolve the actual backend (``"inplace"`` / ``"transform"``) from config.""" - requested = config.backend - fuse_requested = bool(config.fuse_online_to_weight) - allow_online_rotation: bool = config.allow_online_rotation - - if requested == "inplace": - return "inplace" - - transform_backend_name = "transform" - if requested == "transform": - if fuse_requested: - raise ValueError( - f"backend='{transform_backend_name}' does not support fuse_online_to_weight=True. " - "Use backend='inplace' (or backend='auto' with fuse_online_to_weight=True) instead." - ) - if not (is_mx_fp(data_type) or is_nv_fp(data_type)): - raise ValueError( - f"backend='{transform_backend_name}' only supports MXFP4 / NVFP4 (got data_type={data_type!r}). " - "Use backend='inplace' or backend='auto' for other dtypes." - ) - if not allow_online_rotation: - raise ValueError(f"backend='{transform_backend_name}' only supports `allow_online_rotation`=True") - - return "transform" - - # backend == "auto" - if fuse_requested: - return "inplace" - if is_mx_fp(data_type) or is_nv_fp(data_type): - return "transform" - return "inplace" - - -def apply_hadamard_rotation( - model: torch.nn.Module, - rotation_config: Union[str, dict, RotationConfig, None], - data_type: str, - compute_device: torch.device | str = None, -) -> (torch.nn.Module, Any): - """Apply Hadamard rotation/transform to *model*, dispatching by backend. - - Args: - model: Target model. - rotation_config: ``str`` / ``dict`` / :class:`RotationConfig` / ``None``. - See :class:`RotationConfig` for fields. - data_type: Quantization data type (e.g. ``"mx_fp"``, ``"nv_fp"``, - ``"int"``, ``"fp"``). - compute_device: Device for inplace-backend computation. Ignored by - the transform backend. - - Returns: - The same model (for chaining); also stored on ``model.rotation_config``. 
- """ - config = _to_config(rotation_config, data_type) - backend = resolve_hadamard_backend(config, data_type) - - # Resolve fuse flag: explicit > env var > default(True) - fuse_online_to_weight = config.fuse_online_to_weight - if config.fuse_online_to_weight is not None: - fuse_online_to_weight = bool(config.fuse_online_to_weight) - elif envs.AR_FUSE_ONLINE_ROTATION: - fuse_online_to_weight = bool(envs.AR_FUSE_ONLINE_ROTATION) - - logger.info( - f"Applying Hadamard (backend={backend}, " - f"data_type={data_type}, fuse_online_to_weight={fuse_online_to_weight if backend == 'inplace' else False})." - ) - - if backend == "inplace": - logger.warning("this backend does not support real exporting, please export the model to fake format") - from auto_round.experimental.rotation_inplace import apply_rotation_transform - - # block_size -> group_size (None / -1 / 0 means full-dimension) - bs = config.block_size - group_size = bs if (bs is not None and bs > 0) else None - - model, hooks = apply_rotation_transform( - model, - group_size=group_size, - allow_online_rotation=config.allow_online_rotation, - rotation_matrix=config.hadamard_type, - fuse_online_to_weight=fuse_online_to_weight, - compute_device=compute_device, - ) - # Stash for downstream (export / serialization). Plain dict so JSON - # serialization (HF save_pretrained -> config.json) round-trips. - setattr(model, "rotation_config", config.model_dump() if hasattr(config, "model_dump") else config) - return model, hooks - - elif backend == "transform": - supported_hadamard_types = ("hadamard", "random_hadamard") - if config.hadamard_type not in supported_hadamard_types: - raise ValueError("this backend only supports hadamard or random_hadamard") - from auto_round.experimental.transform.apply import apply_rotation_transform - - return apply_rotation_transform(model, config, data_type=data_type) - else: - raise ValueError(f"Unsupported Hadamard backend {backend!r}") +from auto_round.algorithms.transforms.rotation.dispatcher import ( # noqa: F401 + apply_hadamard_rotation, + resolve_hadamard_backend, +) diff --git a/auto_round/experimental/rotation_inplace/__init__.py b/auto_round/experimental/rotation_inplace/__init__.py index 8cdef31b0..07b3d40c8 100644 --- a/auto_round/experimental/rotation_inplace/__init__.py +++ b/auto_round/experimental/rotation_inplace/__init__.py @@ -1,5 +1,14 @@ # # Copyright (C) 2026 Intel Corporation # # SPDX-License-Identifier: Apache-2.0 +"""Backward-compat re-export shim. -from auto_round.experimental.rotation_inplace.apply_rotation_transform import apply_rotation_transform -from auto_round.experimental.rotation_inplace.utils import clear_random_hadamard_cache +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.inplace`. +""" + +from auto_round.algorithms.transforms.rotation.inplace.apply import ( # noqa: F401 + apply_rotation_transform, +) +from auto_round.algorithms.transforms.rotation.inplace.hooks import ( # noqa: F401 + clear_random_hadamard_cache, +) diff --git a/auto_round/experimental/rotation_inplace/apply_rotation_transform.py b/auto_round/experimental/rotation_inplace/apply_rotation_transform.py index 1052b036f..86e429c10 100644 --- a/auto_round/experimental/rotation_inplace/apply_rotation_transform.py +++ b/auto_round/experimental/rotation_inplace/apply_rotation_transform.py @@ -1,882 +1,12 @@ # # Copyright (C) 2026 Intel Corporation # # SPDX-License-Identifier: Apache-2.0 +"""Backward-compat re-export shim. 
-"""Hadamard inplace rotation — public API and rotation primitives. - -Supports LLaMA-2, LLaMA-3, Qwen-3 (and any model with the same layout). -The entry point is :func:`apply_hadamard_rotation`. +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.inplace.apply`. """ -import gc -import typing -from typing import Dict, Union - -import torch -import tqdm - -from auto_round.experimental.rotation_inplace.model_config import ( - MAPPING_REGISTRY, - RotationMapping, - _resolve, - infer_mapping_from_model, -) -from auto_round.experimental.rotation_inplace.utils import ( - CrossHeadOnlineHadamardHook, - FullOnlineHadamardHook, - GroupOnlineHadamardHook, - _get_custom_had, - _normalize_rotation_matrix, - _resolve_compute_device, - _rotate_embedding_grouped, - _rotate_linear_grouped, - apply_cross_head_had_to_linear, - apply_exact_had_to_linear, - deterministic_hadamard_matrix, - get_hadK, - get_or_create_random_hadamard, +from auto_round.algorithms.transforms.rotation.inplace.apply import * # noqa: F401, F403 +from auto_round.algorithms.transforms.rotation.inplace.apply import ( # noqa: F401 + apply_rotation_transform, ) - -# --------------------------------------------------------------------------- -# Low-level primitives (model-agnostic via RotationMapping) -# --------------------------------------------------------------------------- - - -def _fuse_ln_linear( - layernorm: torch.nn.Module, - linear_layers: typing.Iterable[torch.nn.Linear], -) -> None: - """Fuse the linear operations in LayerNorm into adjacent linear blocks.""" - for linear in linear_layers: - linear_dtype = linear.weight.dtype - dev = linear.weight.device - - W_ = linear.weight.data.double() - ln_weight = layernorm.weight.double().to(dev) - linear.weight.data = (W_ * ln_weight).to(linear_dtype) - - if hasattr(layernorm, "bias") and layernorm.bias is not None: - if linear.bias is None: - linear.bias = torch.nn.Parameter(torch.zeros(linear.out_features, dtype=torch.float64, device=dev)) - ln_bias = layernorm.bias.double().to(dev) - linear.bias.data = linear.bias.data.double() + torch.matmul(W_, ln_bias) - linear.bias.data = linear.bias.data.to(linear_dtype) - - -def _reset_ln_params(layernorm: torch.nn.Module) -> None: - """Reset LayerNorm to identity: weight=1, bias=0.""" - layernorm.weight.data.fill_(1.0) - if hasattr(layernorm, "bias") and layernorm.bias is not None: - layernorm.bias.data.fill_(0.0) - - -def _rotate_linear_by_Q(module: torch.nn.Linear, Q: torch.Tensor, side: str, compute_device=None) -> None: - """Apply rotation *Q* to a Linear layer's weight (and bias if present). - - Args: - side: ``'input'`` → W = W @ Q (rotate input side) - ``'output'`` → W = Q^T @ W (rotate output side) - compute_device: Device to run computation on. If None, auto-detects GPU. - """ - dtype = module.weight.data.dtype - dev = module.weight.data.device - cdev = _resolve_compute_device(compute_device) - W_ = module.weight.data.to(device=cdev, dtype=torch.float64) - Q_ = Q.to(device=cdev) - if side == "input": - new_W = torch.matmul(W_, Q_).to(device=dev, dtype=dtype) - else: - new_W = torch.matmul(Q_.T, W_).to(device=dev, dtype=dtype) - # Release fp64 copy before assigning back so peak memory ≈ 1× weight + 1× rotated. 
- del W_ - module.weight.data = new_W - if side == "output" and module.bias is not None: - b = module.bias.data.to(device=cdev, dtype=torch.float64) - new_b = torch.matmul(Q_.T, b).to(device=dev, dtype=dtype) - del b - module.bias.data = new_b - del Q_ - - -def _untie_word_embeddings(model, mapping: RotationMapping) -> None: - """Break tied weights between lm_head and embedding if they share the same tensor.""" - embedding = _resolve(model, mapping.embedding) - lm_head = _resolve(model, mapping.lm_head) - - if lm_head.weight.data_ptr() != embedding.weight.data_ptr(): - return - - lm_head.weight = torch.nn.Parameter(lm_head.weight.data.clone()) - if hasattr(model.config, "tie_word_embeddings"): - model.config.tie_word_embeddings = False - - -def _uses_layernorm_with_mean(model, mapping: RotationMapping) -> bool: - """Check whether the model uses standard LayerNorm (which subtracts mean).""" - layers = _resolve(model, mapping.layers_attr) - first_ln = _resolve(layers[0], mapping.attn_input_ln) - return isinstance(first_ln, torch.nn.LayerNorm) - - -def _bake_mean_into_linear(linear: torch.nn.Linear) -> None: - """Subtract column-wise mean from a Linear layer's weight (and mean from bias).""" - linear_dtype = linear.weight.dtype - W_ = linear.weight.data.double() - linear.weight.data = (W_ - W_.mean(dim=-2, keepdim=True)).to(linear_dtype) - if linear.bias is not None: - b_ = linear.bias.data.double() - linear.bias.data = (b_ - b_.mean()).to(linear_dtype) - - -def _subtract_embedding_mean(model, mapping: RotationMapping) -> None: - """Subtract per-row mean from the embedding weight matrix.""" - W = _resolve(model, mapping.embedding) - dtype = W.weight.data.dtype - W_ = W.weight.data.to(dtype=torch.float64) - W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(dtype=dtype) - - if mapping.positional_embedding is not None: - P = _resolve(model, mapping.positional_embedding) - p_dtype = P.weight.data.dtype - P_ = P.weight.data.to(dtype=torch.float64) - P.weight.data = (P_ - P_.mean(dim=-1, keepdim=True)).to(dtype=p_dtype) - - -class _RMSNorm(torch.nn.Module): - """RMS Normalization (no mean subtraction).""" - - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.register_buffer("weight", torch.ones(dim)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) - return x / rms * self.weight - - -def _replace_layernorms_with_rmsnorm(model) -> None: - """Replace all ``nn.LayerNorm`` modules with ``_RMSNorm``.""" - replacements = [] - for name, module in model.named_modules(): - if isinstance(module, torch.nn.LayerNorm): - replacements.append((name, module)) - - for name, module in replacements: - parts = name.rsplit(".", 1) - if len(parts) == 2: - parent = _resolve(model, parts[0]) - attr = parts[1] - else: - parent = model - attr = parts[0] - rms = _RMSNorm(module.normalized_shape[0], eps=module.eps) - rms = rms.to(device=module.weight.device, dtype=module.weight.dtype) - setattr(parent, attr, rms) - - -# --------------------------------------------------------------------------- -# High-level steps driven by RotationMapping -# --------------------------------------------------------------------------- - - -def _fuse_layer_norms(model, mapping: RotationMapping) -> None: - """Fuse all LayerNorm parameters into adjacent Linear layers.""" - layers = _resolve(model, mapping.layers_attr) - - for layer in layers: - mlp_ln = _resolve(layer, mapping.mlp_input_ln) - mlp_linears = [_resolve(layer, 
p) for p in mapping.mlp_in] - _fuse_ln_linear(mlp_ln, mlp_linears) - _reset_ln_params(mlp_ln) - - attn_ln = _resolve(layer, mapping.attn_input_ln) - attn_linears = [ - _resolve(layer, mapping.attn_q), - _resolve(layer, mapping.attn_k), - _resolve(layer, mapping.attn_v), - ] - _fuse_ln_linear(attn_ln, attn_linears) - _reset_ln_params(attn_ln) - - pre_head_ln = _resolve(model, mapping.pre_head_ln) - lm_head = _resolve(model, mapping.lm_head) - _fuse_ln_linear(pre_head_ln, [lm_head]) - _reset_ln_params(pre_head_ln) - - -# --------------------------------------------------------------------------- -# Unified weight rotation (full or grouped) -# --------------------------------------------------------------------------- - - -@torch.inference_mode() -def _rotate_weights( - model, - mapping: RotationMapping, - use_fast_had: bool = True, - group_size: int = None, - compute_device: torch.device = None, - had_dict: dict = None, - preset: str = None, - fuse_online_to_weight: bool = True, -) -> None: - """Apply Hadamard rotation to all weights. - - Args: - group_size: ``None`` → full Hadamard rotation. - ``int`` → block-diagonal rotation with this block size. - compute_device: Device to run Hadamard computation on (e.g. ``"cuda:0"``). - Weights are moved there temporarily and moved back afterwards. - If ``None``, auto-detects GPU availability. - allow_online_rotation: If ``True`` (default), apply extra input-side - Hadamard rotations on ``down_proj`` and the OV pair (``v_proj`` - output + ``o_proj`` input) that require compensating online hooks - at inference time. If ``False``, skip those extra rotations so - that **no** online hooks are needed. - had_dict: Normalized ``dict[int, Tensor]`` of custom Hadamard matrices - (keyed by dimension). Only used in grouped mode. - preset: Rotation preset name (``"quarot_hadamard"``, ``"hadamard"``, - ``"random_hadamard"``, or ``None``). - - * ``"quarot_hadamard"``: fusable (residual-stream) rotations use - ``fast_hadamard_transform`` / random Hadamard; non-fusable - (online-paired) rotations and their weight-side counterparts use - deterministic ``get_hadK``/``matmul_hadU`` so that the online - hook at inference produces the exact same transform. - * ``"hadamard"``: all rotations use deterministic ``get_hadK`` / - ``matmul_hadU``. Full-mode Q is a deterministic Hadamard matrix. - * ``"random_hadamard"``: all rotations use random Hadamard matrices - from the global cache (``get_or_create_random_hadamard``). - Same dimension → same matrix everywhere. - * ``None``: same behaviour as ``"hadamard"`` (built-in butterfly). 
- """ - compute_device = _resolve_compute_device(compute_device) - config = model.config - hidden_size = getattr(config, mapping.hidden_size_attr) - intermediate_size = getattr(config, mapping.intermediate_size_attr) - num_heads = getattr(config, mapping.num_heads_attr) - head_dim = mapping.attn_head_dim or (hidden_size // num_heads) - - is_grouped = group_size is not None and group_size > 0 - desc = f"Rotating (group_size={group_size})" if is_grouped else "Rotating" - - # ----- Resolve per-operation Hadamard sources ----- - fused_fast = use_fast_had - online_fast = False - if preset == "random_hadamard": - fused_fast = False - - # -- Matrix resolution -- - had_matrix, _found = _get_custom_had(had_dict, group_size) if is_grouped else (None, False) - - online_had_matrix = had_matrix - if preset == "random_hadamard" and had_matrix is None: - had_matrix = get_or_create_random_hadamard(group_size if is_grouped else hidden_size, compute_device) - online_had_matrix = had_matrix - if preset == "quarot_hadamard" and is_grouped: - online_had_matrix = None # force deterministic for online-paired - - # -- Helper: look up cached random matrix for online-paired ops -- - def _online_had(dim): - """Return cached random matrix for *dim* under random_hadamard, else None.""" - if preset == "random_hadamard": - return get_or_create_random_hadamard(dim, compute_device) - return None - - if is_grouped: - assert hidden_size % group_size == 0, f"group_size={group_size} must divide hidden_size={hidden_size}" - assert ( - intermediate_size % group_size == 0 - ), f"group_size={group_size} must divide intermediate_size={intermediate_size}" - - # --- Full mode: build Hadamard matrix Q --- - Q = None - if not is_grouped: - if preset == "hadamard": - Q = deterministic_hadamard_matrix(hidden_size, compute_device) - else: - # "random_hadamard", "quarot_hadamard", None — same shape → same matrix - Q = get_or_create_random_hadamard(hidden_size, compute_device) - - # ---- Top-level: embedding / lm_head ---- - # When fuse_online_to_weight=False, skip embedding and lm_head rotation: - # each layer is self-contained (weight rotation + online hook cancel out). 
- if fuse_online_to_weight: - embedding = _resolve(model, mapping.embedding) - if is_grouped: - _rotate_embedding_grouped( - embedding, group_size, use_fast_had=fused_fast, compute_device=compute_device, had_matrix=had_matrix - ) - else: - dtype = embedding.weight.data.dtype - dev = embedding.weight.data.device - cdev = compute_device - W_ = embedding.weight.data.to(device=cdev, dtype=torch.float64) - new_W = torch.matmul(W_, Q.to(cdev)).to(device=dev, dtype=dtype) - del W_ - embedding.weight.data = new_W - - if mapping.positional_embedding is not None: - pos_emb = _resolve(model, mapping.positional_embedding) - if is_grouped: - _rotate_embedding_grouped( - pos_emb, group_size, use_fast_had=fused_fast, compute_device=compute_device, had_matrix=had_matrix - ) - else: - pos_dtype = pos_emb.weight.data.dtype - pos_dev = pos_emb.weight.data.device - cdev = compute_device - P_ = pos_emb.weight.data.to(device=cdev, dtype=torch.float64) - new_P = torch.matmul(P_, Q.to(cdev)).to(device=pos_dev, dtype=pos_dtype) - del P_ - pos_emb.weight.data = new_P - - # ---- Top-level: lm_head ---- - lm_head = _resolve(model, mapping.lm_head) - if is_grouped: - _rotate_linear_grouped( - lm_head, - group_size, - side="input", - use_fast_had=fused_fast, - compute_device=compute_device, - had_matrix=had_matrix, - ) - else: - _rotate_linear_by_Q(lm_head, Q, side="input", compute_device=compute_device) - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # ---- Per-layer rotation ---- - layers = _resolve(model, mapping.layers_attr) - for layer in tqdm.tqdm(layers, unit="layer", desc=desc): - if fuse_online_to_weight: - # ---- fuse mode: QuaRot-style residual stream rotation ---- - # Q/K/V: only residual Q on input (no online Had stacking, no hook). - # When Q == online Had (e.g. preset="hadamard"), Q @ Q = I cancels - # the rotation entirely, destroying quantization benefit. - # gate/up: only residual Q on input (no online Had stacking, no hook). - # down_proj: residual Q^T on output + online Had on input (+ hook). - # v_proj/o_proj: per-head/cross-head Had below (+ hook on o_proj). 
- for attr in (mapping.attn_q, mapping.attn_k, mapping.attn_v): - mod = _resolve(layer, attr) - if is_grouped: - _rotate_linear_grouped( - mod, - group_size, - side="input", - use_fast_had=fused_fast, - compute_device=compute_device, - had_matrix=had_matrix, - ) - else: - _rotate_linear_by_Q(mod, Q, side="input", compute_device=compute_device) - - # o_proj: residual stream output rotation - if is_grouped: - _rotate_linear_grouped( - _resolve(layer, mapping.attn_o), - group_size, - side="output", - use_fast_had=fused_fast, - compute_device=compute_device, - had_matrix=had_matrix, - ) - else: - _rotate_linear_by_Q(_resolve(layer, mapping.attn_o), Q, side="output", compute_device=compute_device) - - # gate/up: only residual Q on input - for attr in mapping.mlp_in: - mod = _resolve(layer, attr) - if is_grouped: - _rotate_linear_grouped( - mod, - group_size, - side="input", - use_fast_had=fused_fast, - compute_device=compute_device, - had_matrix=had_matrix, - ) - else: - _rotate_linear_by_Q(mod, Q, side="input", compute_device=compute_device) - - # down_proj: residual output + online input Had - down_proj = _resolve(layer, mapping.mlp_out) - if is_grouped: - _rotate_linear_grouped( - down_proj, - group_size, - side="output", - use_fast_had=fused_fast, - compute_device=compute_device, - had_matrix=had_matrix, - ) - _rotate_linear_grouped( - down_proj, - group_size, - side="input", - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=online_had_matrix, - ) - else: - _rotate_linear_by_Q(down_proj, Q, side="output", compute_device=compute_device) - apply_exact_had_to_linear( - down_proj, - had_dim=-1, - output=False, - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=_online_had(intermediate_size), - ) - - # OV projection: v_proj per-head output + o_proj full/cross-head input - v_proj = _resolve(layer, mapping.attn_v) - o_proj = _resolve(layer, mapping.attn_o) - if is_grouped: - pass - else: - online_head_had = _online_had(head_dim) - apply_exact_had_to_linear( - v_proj, - had_dim=head_dim, - output=True, - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=online_head_had, - ) - if preset == "random_hadamard": - apply_exact_had_to_linear( - o_proj, - had_dim=head_dim, - output=False, - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=online_head_had, - ) - apply_cross_head_had_to_linear( - o_proj, - num_heads, - head_dim, - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=_online_had(num_heads), - ) - else: - apply_exact_had_to_linear( - o_proj, - had_dim=-1, - output=False, - use_fast_had=online_fast, - compute_device=compute_device, - ) - - else: - # ---- unfused mode: no residual rotation, only input-side Had ---- - # Each layer gets Had fused on input side + compensating hook → equivalent. - # No embedding/lm_head rotation. No self-cancelling pair. - # v_proj treated same as Q/K (input Had only, no per-head/cross-head). 
- - # Q/K/V: input-side Had on hidden_size - for attr in (mapping.attn_q, mapping.attn_k, mapping.attn_v): - mod = _resolve(layer, attr) - if is_grouped: - _rotate_linear_grouped( - mod, - group_size, - side="input", - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=online_had_matrix, - ) - else: - apply_exact_had_to_linear( - mod, - had_dim=-1, - output=False, - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=_online_had(hidden_size), - ) - - # o_proj: input-side Had on hidden_size (full Had, not cross-head) - o_proj = _resolve(layer, mapping.attn_o) - if is_grouped: - _rotate_linear_grouped( - o_proj, - group_size, - side="input", - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=online_had_matrix, - ) - else: - apply_exact_had_to_linear( - o_proj, - had_dim=-1, - output=False, - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=_online_had(hidden_size), - ) - - # gate/up: input-side Had on hidden_size - for attr in mapping.mlp_in: - mod = _resolve(layer, attr) - if is_grouped: - _rotate_linear_grouped( - mod, - group_size, - side="input", - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=online_had_matrix, - ) - else: - apply_exact_had_to_linear( - mod, - had_dim=-1, - output=False, - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=_online_had(hidden_size), - ) - - # down_proj: input-side Had on intermediate_size - down_proj = _resolve(layer, mapping.mlp_out) - if is_grouped: - _rotate_linear_grouped( - down_proj, - group_size, - side="input", - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=online_had_matrix, - ) - else: - apply_exact_had_to_linear( - down_proj, - had_dim=-1, - output=False, - use_fast_had=online_fast, - compute_device=compute_device, - had_matrix=_online_had(intermediate_size), - ) - - # Per-layer cleanup: drop fp64 temporaries and CUDA caching allocator - # blocks so peak memory stays at ~1 layer's worth instead of accumulating - # across all 32+ decoder layers (was the main cause of 33 GB RAM on 8B). - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -# --------------------------------------------------------------------------- -# Unified online hook registration -# --------------------------------------------------------------------------- - - -def _register_online_hooks( - model, - mapping: RotationMapping, - fp32_had: bool = False, - use_fast_had: bool = True, - group_size: int = None, - had_dict: dict = None, - preset: str = None, - fuse_online_to_weight: bool = True, -): - """Register online Hadamard pre-forward hooks on ``down_proj`` and ``o_proj``. - - Online hooks must use the **same** Hadamard matrix that was applied to the - weight-side counterpart during ``_rotate_weights``. For ``quarot_hadamard`` - this is always the deterministic ``get_hadK``/``matmul_hadU`` path - (``use_fast_had=False``). For ``"random_hadamard"`` it is the random matrix that - was generated once and stored in ``had_dict``. - - Args: - group_size: ``None`` → full Hadamard hooks (original QuaRot). - ``int`` → per-group Hadamard hooks. - had_dict: Normalized ``dict[int, Tensor]`` of custom Hadamard matrices. - preset: Rotation preset name. - Returns: - list of hook handles. 
- """ - config = model.config - num_heads = getattr(config, mapping.num_heads_attr) - hidden_size = getattr(config, mapping.hidden_size_attr) - intermediate_size = getattr(config, mapping.intermediate_size_attr) - head_dim = mapping.attn_head_dim or (hidden_size // num_heads) - - is_grouped = group_size is not None and group_size > 0 - - # Online hooks always use deterministic (fixed) Hadamard — never fast_had - # for quarot_hadamard; for "random_hadamard" they use the same random matrix - # that was cached in had_dict by _rotate_weights. - online_fast = False - - # -- Matrix resolution (must match the *online-paired* matrix used by - # _rotate_weights for down_proj input / OV pair). Variable name kept in - # sync with _rotate_weights to make any future drift obvious. - online_had_matrix, _ = _get_custom_had(had_dict, group_size) if is_grouped else (None, False) - if preset == "random_hadamard" and online_had_matrix is None: - online_had_matrix = get_or_create_random_hadamard(group_size if is_grouped else hidden_size) - if preset == "quarot_hadamard" and is_grouped: - online_had_matrix = None - - # -- Helper: look up cached random matrix for online-paired hooks -- - def _online_had(dim): - if preset == "random_hadamard": - return get_or_create_random_hadamard(dim) - return None - - mlp_out_suffix = mapping.mlp_out.split(".")[-1] - attn_o_suffix = mapping.attn_o.split(".")[-1] - - # Suffixes for Q/K/V and gate/up (for online input Had hooks) - attn_qkv_suffixes = set(attr.split(".")[-1] for attr in (mapping.attn_q, mapping.attn_k, mapping.attn_v)) - mlp_in_suffixes = set(attr.split(".")[-1] for attr in mapping.mlp_in) - - # --- Build hook factories --- - def _make_down_proj_hook(): - if is_grouped: - return GroupOnlineHadamardHook( - group_size=group_size, fp32_had=fp32_had, use_fast_had=online_fast, had_matrix=online_had_matrix - ) - online_mat = _online_had(intermediate_size) - if online_mat is not None: - return FullOnlineHadamardHook( - had_K=None, K=None, fp32_had=fp32_had, use_fast_had=online_fast, had_matrix=online_mat - ) - had_K, K = get_hadK(intermediate_size) - return FullOnlineHadamardHook(had_K=had_K, K=K, fp32_had=fp32_had, use_fast_had=online_fast) - - def _make_hidden_had_hook(): - """Full Had hook on hidden_size (for Q/K/V and gate/up input).""" - if is_grouped: - return GroupOnlineHadamardHook( - group_size=group_size, fp32_had=fp32_had, use_fast_had=online_fast, had_matrix=online_had_matrix - ) - online_mat = _online_had(hidden_size) - if online_mat is not None: - return FullOnlineHadamardHook( - had_K=None, K=None, fp32_had=fp32_had, use_fast_had=online_fast, had_matrix=online_mat - ) - had_K, K = get_hadK(hidden_size) - return FullOnlineHadamardHook(had_K=had_K, K=K, fp32_had=fp32_had, use_fast_had=online_fast) - - def _make_o_proj_hook(): - online_mat = _online_had(num_heads) - if online_mat is not None: - return CrossHeadOnlineHadamardHook( - had_K=None, - K=None, - head_dim=head_dim, - fp32_had=fp32_had, - use_fast_had=online_fast, - had_matrix=online_mat, - ) - had_K, K = get_hadK(num_heads) - return CrossHeadOnlineHadamardHook( - had_K=had_K, - K=K, - head_dim=head_dim, - fp32_had=fp32_had, - use_fast_had=online_fast, - ) - - # --- Register --- - handles = [] - - for name, module in model.named_modules(): - if not isinstance(module, torch.nn.Linear): - continue - suffix = name.split(".")[-1] - - if name.endswith(mlp_out_suffix): - # down_proj: full Had on intermediate_size input - h = module.register_forward_pre_hook(_make_down_proj_hook()) - handles.append(h) - 
elif name.endswith(attn_o_suffix): - if fuse_online_to_weight and not is_grouped: - # o_proj: cross-head Had on input (fused mode, full only) - h = module.register_forward_pre_hook(_make_o_proj_hook()) - handles.append(h) - elif not fuse_online_to_weight: - # o_proj: full Had on hidden_size input (unfused mode, matches weight rotation) - h = module.register_forward_pre_hook(_make_hidden_had_hook()) - handles.append(h) - elif suffix in attn_qkv_suffixes: - if not fuse_online_to_weight: - # Q/K/V: full Had on hidden_size input (unfused mode only). - # In fused mode Q/K/V only have residual Q on weight (no online Had), - # and activations come pre-rotated from residual stream → no hook needed. - h = module.register_forward_pre_hook(_make_hidden_had_hook()) - handles.append(h) - elif suffix in mlp_in_suffixes: - if not fuse_online_to_weight: - # gate/up: full Had on hidden_size input (unfused mode only). - # Same reasoning as Q/K/V above. - h = module.register_forward_pre_hook(_make_hidden_had_hook()) - handles.append(h) - - return handles - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - -def apply_rotation_transform( - model, - group_size: int = None, - allow_online_rotation: bool = True, - rotation_matrix: Union[str, torch.Tensor, Dict[int, torch.Tensor], None] = None, - compute_device: torch.device | str = None, - fp32_had: bool = False, - fuse_online_to_weight: bool = None, -): - """Fuse layer norms, rotate weights, and register online Hadamard hooks. - - This is the single entry point for applying Hadamard inplace rotation. - The model architecture is auto-detected via ``model.config.model_type``. - - Args: - model: A HuggingFace CausalLM model (LLaMA-2/3, Qwen-3, etc.). - fp32_had: Whether to compute the online Hadamard transform in fp32. - group_size: If ``None`` (default), use full-dimension Hadamard rotation. - compute_device: Device to run Hadamard computation on. - allow_online_rotation: If ``True`` (default), apply online Hadamard - rotations on ``down_proj`` input and the OV pair. - rotation_matrix: Rotation matrix selection (``"hadamard"``, - ``"random_hadamard"``, ``"quarot_hadamard"``, Tensor, dict, or None). - fuse_online_to_weight: If ``True`` (default), fuse online Hadamard - rotation into weights (down_proj input, v_proj output, o_proj input) - and register compensating online hooks. If ``False``, skip - embedding/lm_head rotation; each linear layer is self-contained - with input-side Had on weight + compensating online hook on - activation. No v_proj cross-head or inner-head rotation. - - Returns: - list of hook handles.""" - if fuse_online_to_weight is None: - if model.config.model_type in MAPPING_REGISTRY or model.__class__.__name__ in MAPPING_REGISTRY: - fuse_online_to_weight = True - else: - fuse_online_to_weight = False - had_dict, use_fast_had, preset = _normalize_rotation_matrix(rotation_matrix, group_size) - compute_device = _resolve_compute_device(compute_device) - - if use_fast_had: - from auto_round.utils import logger - - try: - import fast_hadamard_transform # noqa: F401 - - if group_size is None: - logger.warning( - "fast_hadamard_transform uses a different Hadamard matrix than the " - "default implementation. Please ensure consistency between training " - "and inference. This will be refined later." 
- ) - except ImportError: - logger.warning("Importing fast_hadamard_transform failed, falling back to default implementation.") - use_fast_had = False - - mapping = infer_mapping_from_model(model) - - _untie_word_embeddings(model, mapping) - - if _uses_layernorm_with_mean(model, mapping): - _subtract_embedding_mean(model, mapping) - - _fuse_layer_norms(model, mapping) - - if _uses_layernorm_with_mean(model, mapping): - layers = _resolve(model, mapping.layers_attr) - for layer in layers: - _bake_mean_into_linear(_resolve(layer, mapping.attn_o)) - _bake_mean_into_linear(_resolve(layer, mapping.mlp_out)) - _replace_layernorms_with_rmsnorm(model) - - _rotate_weights( - model, - mapping, - use_fast_had=use_fast_had, - group_size=group_size, - compute_device=compute_device, - had_dict=had_dict, - preset=preset, - fuse_online_to_weight=fuse_online_to_weight, - ) - - handles = [] - if fuse_online_to_weight or allow_online_rotation: - handles = _register_online_hooks( - model, - mapping, - fp32_had=fp32_had, - use_fast_had=use_fast_had, - group_size=group_size, - had_dict=had_dict, - preset=preset, - fuse_online_to_weight=fuse_online_to_weight, - ) - - return model, handles - - -# --------------------------------------------------------------------------- -# Quick smoke test -# --------------------------------------------------------------------------- - -if __name__ == "__main__": - from transformers import AutoModelForCausalLM, AutoTokenizer - - model_name = "/models/opt-125m" - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") - model.to("cuda") - text = "There is a girl who likes adventure," - inputs = tokenizer(text, return_tensors="pt").to(model.device) - print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) - - apply_rotation_transform( - model, group_size=-1, allow_online_rotation=True, rotation_matrix="random_hadamard", fuse_online_to_weight=False - ) - model.to("cuda") - text = "There is a girl who likes adventure," - inputs = tokenizer(text, return_tensors="pt").to(model.device) - print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) - - model_name = "/models/Qwen3-8B" - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") - apply_rotation_transform(model, group_size=-1, allow_online_rotation=True, fuse_online_to_weight=True) - model.to("cuda") - text = "There is a girl who likes adventure," - inputs = tokenizer(text, return_tensors="pt").to(model.device) - print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) - - from transformers import AutoModelForCausalLM, AutoTokenizer - - model_name = "/models/Meta-Llama-3.1-8B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") - apply_rotation_transform(model, fuse_online_to_weight=True, group_size=32) - model.to("cuda") - text = "There is a girl who likes adventure," - inputs = tokenizer(text, return_tensors="pt").to(model.device) - print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) - # - # model_name = "/models/Llama-2-7b-chat-hf" - # tokenizer = AutoTokenizer.from_pretrained(model_name) - # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") - # apply_hadamard_rotation(model) - # 
model.to("cuda") - # text = "There is a girl who likes adventure," - # inputs = tokenizer(text, return_tensors="pt").to(model.device) - # print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) diff --git a/auto_round/experimental/rotation_inplace/model_config.py b/auto_round/experimental/rotation_inplace/model_config.py index 3ecbf9b69..35078cd28 100644 --- a/auto_round/experimental/rotation_inplace/model_config.py +++ b/auto_round/experimental/rotation_inplace/model_config.py @@ -1,169 +1,9 @@ # # Copyright (C) 2026 Intel Corporation # # SPDX-License-Identifier: Apache-2.0 +"""Backward-compat re-export shim. -"""Model architecture mapping for Hadamard rotation. - -Each :class:`RotationMapping` describes *where* the rotation-relevant modules -live inside a model. Currently supports LLaMA-2, LLaMA-3, and Qwen-3 (dense). - -New architectures can be supported by calling :func:`register_mapping`. +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.inplace.model_config`. """ -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Dict, List, Optional - -from auto_round.utils import logger - -__all__ = [ - "RotationMapping", - "register_mapping", - "get_mapping", - "infer_mapping_from_model", - "MAPPING_REGISTRY", -] - - -# --------------------------------------------------------------------------- -# Mapping dataclass -# --------------------------------------------------------------------------- - - -@dataclass -class RotationMapping: - """Declarative description of a transformer architecture for Hadamard rotation. - - Attribute names follow the dot-path convention relative to the model or - each decoder layer. - - Config attribute names (read from ``model.config``): - num_heads_attr, hidden_size_attr, intermediate_size_attr - head_dim_override – explicit head dim (skip hidden_size // num_heads) - """ - - # -- top-level modules (dot-path from model root) -- - embedding: str = "model.embed_tokens" - lm_head: str = "lm_head" - positional_embedding: Optional[str] = None # e.g. 
"model.decoder.embed_positions" for OPT - - # -- layers container (dot-path from model root) -- - layers_attr: str = "model.layers" - - # -- per-layer: attention (dot-path from each layer) -- - attn_input_ln: str = "input_layernorm" - attn_q: str = "self_attn.q_proj" - attn_k: str = "self_attn.k_proj" - attn_v: str = "self_attn.v_proj" - attn_o: str = "self_attn.o_proj" - - # -- per-layer: MLP (dot-path from each layer) -- - mlp_input_ln: str = "post_attention_layernorm" - mlp_in: List[str] = field(default_factory=lambda: ["mlp.up_proj", "mlp.gate_proj"]) - mlp_out: str = "mlp.down_proj" - - # -- final norm (dot-path from model root) -- - pre_head_ln: str = "model.norm" - - # -- head dim override (None = hidden_size // num_heads) -- - attn_head_dim: Optional[int] = None - - # -- config attr names -- - num_heads_attr: str = "num_attention_heads" - hidden_size_attr: str = "hidden_size" - intermediate_size_attr: str = "intermediate_size" - - -# --------------------------------------------------------------------------- -# Helper: resolve a dot-path attribute on a module -# --------------------------------------------------------------------------- - - -def _resolve(root, dot_path: str): - """Resolve ``'a.b.c'`` to ``root.a.b.c``.""" - obj = root - for attr in dot_path.split("."): - obj = getattr(obj, attr) - return obj - - -# --------------------------------------------------------------------------- -# Registry -# --------------------------------------------------------------------------- - -MAPPING_REGISTRY: Dict[str, RotationMapping] = {} - - -def register_mapping(key: str, mapping: RotationMapping) -> RotationMapping: - """Register a :class:`RotationMapping` under *key* (model_type or architecture).""" - MAPPING_REGISTRY[key] = mapping - return mapping - - -def get_mapping(key: str) -> RotationMapping: - """Look up a mapping by *key*; fall back to default if not found.""" - if key in MAPPING_REGISTRY: - return MAPPING_REGISTRY[key] - logger.warning(f"No rotation mapping registered for '{key}', " "falling back to default (LLaMA-like) mapping.") - return RotationMapping() - - -def infer_mapping_from_model(model) -> RotationMapping: - """Return the best :class:`RotationMapping` for *model*. - - Tries ``model.config.model_type`` first, then ``model.__class__.__name__``. - """ - model_type = getattr(getattr(model, "config", None), "model_type", "") - if model_type in MAPPING_REGISTRY: - return MAPPING_REGISTRY[model_type] - - arch = model.__class__.__name__ - if arch in MAPPING_REGISTRY: - return MAPPING_REGISTRY[arch] - - logger.warning( - f"Unrecognised architecture '{arch}' (model_type='{model_type}'). " - "Falling back to default (LLaMA-like) mapping." - ) - return RotationMapping() - - -# =================================================================== -# Built-in mappings -# =================================================================== - -# LLaMA-2 / LLaMA-3 / Mistral / Yi — all share the same layout -_default = RotationMapping() - -register_mapping("llama", _default) -register_mapping("LlamaForCausalLM", _default) - -# Qwen-3 dense — identical layout to LLaMA -register_mapping("qwen3", _default) -register_mapping("Qwen3ForCausalLM", _default) - -# Qwen-2 / Qwen-2.5 dense — identical layout to LLaMA -register_mapping("qwen2", _default) -register_mapping("Qwen2ForCausalLM", _default) - -# ---- OPT ---- -# OPT uses standard LayerNorm (with bias, subtracts mean), -# different module names, and tied lm_head ↔ embedding weights. 
-_opt = RotationMapping( - embedding="model.decoder.embed_tokens", - lm_head="lm_head", - positional_embedding="model.decoder.embed_positions", - layers_attr="model.decoder.layers", - attn_input_ln="self_attn_layer_norm", - attn_q="self_attn.q_proj", - attn_k="self_attn.k_proj", - attn_v="self_attn.v_proj", - attn_o="self_attn.out_proj", - mlp_input_ln="final_layer_norm", - mlp_in=["fc1"], - mlp_out="fc2", - pre_head_ln="model.decoder.final_layer_norm", - intermediate_size_attr="ffn_dim", -) -register_mapping("opt", _opt) -register_mapping("OPTForCausalLM", _opt) +from auto_round.algorithms.transforms.rotation.inplace.model_config import * # noqa: F401, F403 diff --git a/auto_round/experimental/rotation_inplace/utils.py b/auto_round/experimental/rotation_inplace/utils.py index 04bb18981..4ddd80d48 100644 --- a/auto_round/experimental/rotation_inplace/utils.py +++ b/auto_round/experimental/rotation_inplace/utils.py @@ -1,786 +1,9 @@ # # Copyright (C) 2026 Intel Corporation # # SPDX-License-Identifier: Apache-2.0 +"""Backward-compat re-export shim. -"""Online Hadamard transform hooks. - -After weight rotation, down_proj and o_proj require an online Hadamard -transform on their *input activations* at inference time. This module -provides the hooks and a helper to register them on the model. +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.inplace.hooks`. """ -import math - -import torch -import torch.nn as nn - -try: - import fast_hadamard_transform -except ImportError: - fast_hadamard_transform = None - - -def _resolve_compute_device(compute_device) -> torch.device: - """Return *compute_device* if explicitly given, otherwise auto-detect GPU. - - When ``compute_device`` is ``None`` the function checks for CUDA / XPU - availability and returns the first accelerator it finds so that heavy - matrix operations are offloaded to GPU even when the model weights live - on CPU. Falls back to ``torch.device("cpu")`` when no accelerator is - present. - """ - if compute_device is not None: - return torch.device(compute_device) if not isinstance(compute_device, torch.device) else compute_device - if torch.cuda.is_available(): - return torch.device("cuda:0") - if hasattr(torch, "xpu") and torch.xpu.is_available(): - return torch.device("xpu:0") - return torch.device("cpu") - - -BUILTIN_ROTATION_PRESETS = {"quarot_hadamard", "hadamard", "random_hadamard"} - -# Global cache for random Hadamard matrices keyed by dimension. -# Ensures the same shape always returns the exact same random matrix within -# a process, across all calls to ``_rotate_weights`` / ``_register_online_hooks``. -_RANDOM_HADAMARD_CACHE: dict = {} - - -def get_or_create_random_hadamard(dim: int, device=None) -> torch.Tensor: - """Return a random Hadamard matrix for *dim*, creating and caching it if needed. - - The matrix is cached globally in ``_RANDOM_HADAMARD_CACHE`` so that every - caller that requests the same *dim* receives the identical matrix. - """ - if dim in _RANDOM_HADAMARD_CACHE: - mat = _RANDOM_HADAMARD_CACHE[dim] - if device is not None: - mat = mat.to(device) - return mat - mat = random_hadamard_matrix(dim, device or torch.device("cpu")) - _RANDOM_HADAMARD_CACHE[dim] = mat - return mat - - -def clear_random_hadamard_cache(): - """Clear the global random Hadamard matrix cache. - - Call this when you want subsequent ``random_hadamard`` preset runs to - generate fresh random matrices (e.g. between independent experiments). 
- """ - _RANDOM_HADAMARD_CACHE.clear() - - -def _normalize_rotation_matrix(rotation_matrix, group_size): - """Normalize ``rotation_matrix`` into a ``(had_dict, use_fast_had, preset)`` tuple. - - Accepted inputs: - * ``None`` → ``(None, False, None)`` — use built-in butterfly ``matmul_hadU``. - * ``"quarot_hadamard"`` → ``(None, True, "quarot_hadamard")`` — fusable - rotations use ``fast_hadamard_transform`` (random); non-fusable - (online-paired) rotations use deterministic ``get_hadK``/``matmul_hadU``. - * ``"hadamard"`` → ``(None, False, "hadamard")`` — all rotations use - deterministic ``get_hadK``/``matmul_hadU``. - * ``"random_hadamard"`` → ``(None, False, "random_hadamard")`` — all rotations use - ``random_hadamard_matrix``. - * A ``torch.Tensor`` of shape ``(n, n)`` → ``({n: tensor}, False, None)``. - * A ``dict[int, Tensor]`` → ``(dict, False, None)`` — returned as-is. - - Returns: - ``(had_dict, use_fast_had, preset)`` - - Raises: - ValueError: if a non-``str`` *rotation_matrix* is given but - *group_size* is not a positive integer, or an unknown preset. - """ - if rotation_matrix is None: - return None, False, None - - if isinstance(rotation_matrix, str): - if rotation_matrix not in BUILTIN_ROTATION_PRESETS: - raise ValueError( - f"Unknown rotation_matrix preset '{rotation_matrix}'. " - f"Supported presets: {BUILTIN_ROTATION_PRESETS}." - ) - if rotation_matrix == "quarot_hadamard": - return None, True, "quarot_hadamard" - elif rotation_matrix == "hadamard": - return None, False, "hadamard" - else: # "random_hadamard" - return None, False, "random_hadamard" - - is_grouped = group_size is not None and group_size > 0 - if not is_grouped and not isinstance(rotation_matrix, dict): - raise ValueError( - "rotation_matrix (Tensor/dict) can only be used with a positive group_size. " - f"Got group_size={group_size}." - ) - - if isinstance(rotation_matrix, torch.Tensor): - assert ( - rotation_matrix.ndim == 2 and rotation_matrix.shape[0] == rotation_matrix.shape[1] - ), f"rotation_matrix must be square, got shape {rotation_matrix.shape}" - return {rotation_matrix.shape[0]: rotation_matrix}, False, None - - if isinstance(rotation_matrix, dict): - for k, t in rotation_matrix.items(): - assert ( - isinstance(t, torch.Tensor) and t.ndim == 2 and t.shape[0] == t.shape[1] - ), f"rotation_matrix[{k}] must be a square tensor, got shape {t.shape}" - return rotation_matrix, False, None - - raise TypeError( - f"rotation_matrix must be a Tensor, dict[int, Tensor], str, or None. " f"Got {type(rotation_matrix)}." - ) - - -def _get_custom_had(had_dict, size): - """Look up a custom Hadamard matrix for *size* from the normalized dict. - - Returns ``(had_tensor, True)`` if found, ``(None, False)`` otherwise. 
- """ - if had_dict is None: - return None, False - if size in had_dict: - return had_dict[size], True - return None, False - - -# --------------------------------------------------------------------------- -# Hook implementations -# --------------------------------------------------------------------------- - - -class FullOnlineHadamardHook(nn.Module): - """Pre-forward hook: full Hadamard on the entire last dimension (for ``down_proj``).""" - - def __init__(self, had_K, K, fp32_had=False, use_fast_had=True, had_matrix=None): - super().__init__() - self.custom_had = had_matrix is not None - if had_matrix is not None: - self.register_buffer("had_matrix", had_matrix) - self.had_K = None - self.K = None - else: - if had_K is not None: - self.register_buffer("had_K", had_K) - else: - self.had_K = None - self.K = K - self.fp32_had = fp32_had - self.use_fast_had = use_fast_had - - def __call__(self, module: nn.Module, args): - x = args[0] if isinstance(args, tuple) else args - x_dtype = x.dtype - - if self.custom_had: - H = self.had_matrix.to(device=x.device, dtype=x.dtype) - if self.fp32_had: - H = self.had_matrix.to(device=x.device).float() - x = (x.float() @ H.T).to(x_dtype) - else: - x = x @ H.T - elif self.fp32_had: - x = matmul_hadU_cuda(x.float(), self.had_K, self.K, use_fast_had=self.use_fast_had).to(x_dtype) - else: - x = matmul_hadU_cuda(x, self.had_K, self.K, use_fast_had=self.use_fast_had) - - if isinstance(args, tuple): - return (x,) + args[1:] - return x - - -class CrossHeadOnlineHadamardHook(nn.Module): - """Pre-forward hook: **cross-head** Hadamard on the ``num_heads`` dimension - (for ``o_proj``). - - After offline rotation: - - ``v_proj`` absorbed a per-head (within-head) Hadamard on ``head_dim``. - - ``o_proj`` absorbed a full Hadamard on ``hidden_size``. - - Since ``H_full = H_cross ⊗ H_within`` (Kronecker decomposition) and the - within-head part is already cancelled by ``v_proj`` through the attention - path (``H_within² = I``), the online hook only needs to apply the residual - **cross-head** Hadamard (``H_cross ⊗ I``): - - * reshape ``(*, hidden_size)`` → ``(*, num_heads, head_dim)`` - * transpose → ``(*, head_dim, num_heads)`` - * Hadamard on the **num_heads** axis (last dim) - * transpose back and reshape - """ - - def __init__(self, had_K, K, head_dim, fp32_had=False, use_fast_had=True, had_matrix=None): - """ - Args: - had_K: Hadamard sub-matrix from ``get_hadK(num_heads)``. - K: Block size from ``get_hadK(num_heads)``. - head_dim: ``hidden_size // num_attention_heads``. - fp32_had: Compute in fp32. - use_fast_had: If True use fast_hadamard_transform; if False use matmul_hadU. - had_matrix: Optional custom rotation matrix of shape ``(num_heads, num_heads)``. 
- """ - super().__init__() - self.custom_had = had_matrix is not None - if had_matrix is not None: - self.register_buffer("had_matrix", had_matrix) - self.had_K = None - self.K = None - else: - if had_K is not None: - self.register_buffer("had_K", had_K) - else: - self.had_K = None - self.K = K - self.had_dim = head_dim - self.fp32_had = fp32_had - self.use_fast_had = use_fast_had - - def __call__(self, module: nn.Module, args): - x = args[0] if isinstance(args, tuple) else args - x_dtype = x.dtype - - if self.fp32_had: - x = x.float() - - init_shape = x.shape - num_heads = init_shape[-1] // self.had_dim - - if self.custom_had: - H = self.had_matrix.to(device=x.device, dtype=x.dtype) - # reshape (*, hidden) → (*, num_heads, head_dim), transpose → (*, head_dim, num_heads) - x = x.reshape(-1, num_heads, self.had_dim).transpose(1, 2) - # apply H on last dim (num_heads): x @ H.T - x = (x @ H.T).transpose(1, 2) - elif self.use_fast_had and fast_hadamard_transform is not None and self.K == 1: - x = fast_hadamard_transform.hadamard_transform( - x.reshape(-1, num_heads, self.had_dim).transpose(1, 2), - scale=1 / math.sqrt(num_heads), - ).transpose(1, 2) - else: - # Fallback: use matmul_hadU (pure butterfly + had_K, no fast_hadamard_transform) - x = x.reshape(-1, num_heads, self.had_dim).transpose(1, 2) - x = matmul_hadU(x.contiguous()) - x = x.transpose(1, 2) - - if self.fp32_had: - x = x.to(x_dtype) - x = x.reshape(init_shape) - - if isinstance(args, tuple): - return (x,) + args[1:] - return x - - -# --------------------------------------------------------------------------- -# Registration helper -# --------------------------------------------------------------------------- - - -def register_online_had_hooks(model, mapping=None, fp32_had=False, use_fast_had=True): - """Register online Hadamard pre-forward hooks on ``down_proj`` and ``o_proj``. - - * **down_proj** (``online_full_had``): full Hadamard on ``intermediate_size``. - Compensates ``apply_exact_had_to_linear(down_proj, had_dim=-1, output=False)``. - - * **o_proj** (``online cross-head had``): cross-head Hadamard on ``num_heads``. - Compensates the residual after v_proj's within-head Hadamard cancels. - - Args: - model: A HuggingFace model whose weights have already been rotated. - mapping: A :class:`RotationMapping` (auto-inferred if ``None``). - fp32_had: Whether to compute the Hadamard transform in fp32. - use_fast_had: If True use fast_hadamard_transform; if False use matmul_hadU. - - Returns: - list of hook handles (call ``handle.remove()`` to detach). - """ - if mapping is None: - from auto_round.experimental.rotation_inplace.model_config import infer_mapping_from_model - - mapping = infer_mapping_from_model(model) - - config = model.config - num_heads = getattr(config, mapping.num_heads_attr) - hidden_size = getattr(config, mapping.hidden_size_attr) - intermediate_size = getattr(config, mapping.intermediate_size_attr) - head_dim = mapping.attn_head_dim or (hidden_size // num_heads) - - # down_proj: full Hadamard on intermediate_size - had_K_full, K_full = get_hadK(intermediate_size) - - # o_proj: cross-head Hadamard on num_heads - had_K_head, K_head = get_hadK(num_heads) - - # Identify target module suffixes from mapping - mlp_out_suffix = mapping.mlp_out.split(".")[-1] # e.g. "down_proj" - attn_o_suffix = mapping.attn_o.split(".")[-1] # e.g. 
"o_proj" - - handles = [] - for name, module in model.named_modules(): - if name.endswith(mlp_out_suffix) and isinstance(module, nn.Linear): - hook = FullOnlineHadamardHook( - had_K=had_K_full, - K=K_full, - fp32_had=fp32_had, - use_fast_had=use_fast_had, - ) - h = module.register_forward_pre_hook(hook) - handles.append(h) - elif name.endswith(attn_o_suffix) and isinstance(module, nn.Linear): - hook = CrossHeadOnlineHadamardHook( - had_K=had_K_head, - K=K_head, - head_dim=head_dim, - fp32_had=fp32_had, - use_fast_had=use_fast_had, - ) - h = module.register_forward_pre_hook(hook) - handles.append(h) - - return handles - - -def is_pow2(n): - return (n & (n - 1) == 0) and (n > 0) - - -# Adapted from https://github.com/Cornell-RelaxML/quip-sharp/blob/main/lib/utils/matmul_had.py -def get_hadK(n: int, transpose=False) -> (torch.Tensor, int): - hadK, K = None, None - - if is_pow2(n): - K = 1 - return hadK, K - else: - from auto_round.experimental.transform.utils.hadamard import _fetch_hadamard_divisor - - hadK = _fetch_hadamard_divisor(n, torch.float, torch.device("cpu")) - if transpose: - hadK = hadK.T - if hadK is not None: - return hadK, 1 if is_pow2(hadK.shape[0]) else hadK.shape[0] - assert is_pow2(n) - - -def matmul_hadU(X, transpose=False): - n = X.shape[-1] - hadK, K = get_hadK(n, transpose) - input = X.clone().view(-1, n, 1) - output = input.clone() - while input.shape[1] > K: - input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2]) - output = output.view(input.shape) - output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :] - output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :] - output = output.view(input.shape[0], input.shape[1], -1) - input, output = (output, input) - del output - - if K > 1: - # Do not explicitly repeat - OOM - # input = torch.bmm( - # hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input) - # Use bcast instead - input = hadK.view(1, K, K).to(input) @ input - - return input.view(X.shape) / torch.tensor(n).sqrt() - - -def matmul_hadUt(X): - return matmul_hadU(X, transpose=True) - - -def random_hadamard_matrix(size, device): - # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation" - Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) - Q = Q * 2 - 1 - Q = torch.diag(Q) - return matmul_hadU(Q).to(device) - - -def deterministic_hadamard_matrix(size, device): - """Build a deterministic Hadamard matrix of the given *size*. - - Applies the butterfly ``matmul_hadU`` to an identity matrix so that the - result is purely determined by ``get_hadK`` (no random sign flips). 
- """ - Q = torch.eye(size, dtype=torch.float64) - return matmul_hadU(Q).to(device) - - -def matmul_hadU_cuda(X, hadK, K, use_fast_had=True): - n = X.shape[-1] - if not use_fast_had or fast_hadamard_transform is None: - return matmul_hadU(X) - if K == 1: - return fast_hadamard_transform.hadamard_transform(X.contiguous(), 1.0 / torch.tensor(n).sqrt()) - # if transpose: - # hadK = hadK.T.contiguous() - input = X.view(*X.shape[:-1], K, n // K) - input = fast_hadamard_transform.hadamard_transform(input.contiguous(), 1.0 / torch.tensor(n).sqrt()) - input = hadK.to(input.device).to(input.dtype) @ input - return input.reshape(X.shape) - - -def matmul_hadUt_cuda(X, hadK, K, use_fast_had=True): - return matmul_hadU_cuda(X, hadK, K, use_fast_had=use_fast_had) - - -def apply_exact_had_to_linear( - module, had_dim=-1, output=False, use_fast_had=True, compute_device=None, had_matrix=None -): - """Apply Hadamard rotation to a Linear layer's weight in-place. - - Args: - module: ``nn.Linear`` layer. - had_dim: Dimension of each Hadamard block (``-1`` for full dimension). - output: If ``True`` rotate the output (row) side; otherwise input (col). - use_fast_had: Use ``fast_hadamard_transform`` when available. - compute_device: Device to run computation on. - had_matrix: Optional custom rotation matrix. When ``had_dim == -1`` - this should be a square tensor whose size equals - ``out_features`` (output) or ``in_features`` (input). When - ``had_dim > 0`` the size should equal ``had_dim``. - """ - assert isinstance(module, torch.nn.Linear) - in_features, out_features = module.in_features, module.out_features - - if had_dim != -1 and had_matrix is None: - assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!" - - W_ = module.weight.data - dtype = W_.dtype - dev = W_.device - init_shape = W_.shape - compute_dev = _resolve_compute_device(compute_device) - W_ = W_.double().to(compute_dev) - - if had_matrix is not None: - H = had_matrix.to(device=compute_dev, dtype=torch.float64) - if had_dim == -1: - # Full-dimension custom matrix - if output: - # W.T = H @ W.T → W = (H @ W.T).T = W @ H.T - W_ = W_ @ H.T - else: - # W = H @ W (rotate input columns: W_new[i,:] = sum H[i,j]*W[j,:]) - # Actually for input side: W_new = W @ H (each row is rotated) - W_ = W_ @ H.T - else: - # Per-block custom matrix - if output: - W_ = W_.t() - transposed_shape = W_.shape - flat = W_.reshape(-1, had_dim) - W_ = (flat @ H.T).reshape(transposed_shape).t() - else: - flat = W_.reshape(-1, had_dim) - W_ = (flat @ H.T).reshape(init_shape) - elif had_dim == -1: - if output: - had_K, K = get_hadK(out_features) - W_ = matmul_hadU_cuda(W_.t(), had_K, K, use_fast_had=use_fast_had).t() - if not output: - had_K, K = get_hadK(in_features) - W_ = matmul_hadU_cuda(W_, had_K, K, use_fast_had=use_fast_had) - else: - # Apply Hadamard to the last had_dim chunks of the weights - if output: - W_ = W_.t() - transposed_shape = W_.shape - if use_fast_had and fast_hadamard_transform is not None: - W_ = ( - fast_hadamard_transform.hadamard_transform( - W_.reshape(-1, transposed_shape[-1] // had_dim, had_dim), scale=1 / math.sqrt(had_dim) - ) - .reshape(transposed_shape) - .t() - ) - else: - W_ = matmul_hadU(W_.reshape(-1, had_dim)).reshape(transposed_shape).t() - else: - if use_fast_had and fast_hadamard_transform is not None: - n = W_.shape[1] - W_ = fast_hadamard_transform.hadamard_transform( - W_.reshape(-1, n // had_dim, had_dim), scale=1 / math.sqrt(had_dim) - ).reshape(init_shape) - else: - W_ = matmul_hadU(W_.reshape(-1, 
had_dim)).reshape(init_shape) - module.weight.data = W_.to(device=dev, dtype=dtype) - - -def apply_cross_head_had_to_linear( - module, num_heads, head_dim, use_fast_had=True, compute_device=None, had_matrix=None -): - """Apply a cross-head Hadamard rotation to a Linear layer's input side. - - The operation is equivalent to ``(H_cross ⊗ I_head_dim)`` applied to the - input columns: - - * Reshape columns ``(hidden_size,)`` → ``(num_heads, head_dim)`` - * Transpose → ``(head_dim, num_heads)`` - * Hadamard on the ``num_heads`` axis - * Transpose back and reshape - - This mirrors what :class:`CrossHeadOnlineHadamardHook` does at runtime. - - Args: - module: ``nn.Linear`` layer whose ``in_features == num_heads * head_dim``. - num_heads: Number of attention heads. - head_dim: Per-head dimension. - use_fast_had: Use ``fast_hadamard_transform`` when available. - compute_device: Device to run computation on. - had_matrix: Optional custom rotation matrix of shape ``(num_heads, num_heads)``. - """ - assert isinstance(module, torch.nn.Linear) - W_ = module.weight.data - dtype = W_.dtype - dev = W_.device - compute_dev = _resolve_compute_device(compute_device) - W_ = W_.double().to(compute_dev) - - out_f = W_.shape[0] - # W shape: (out_features, hidden_size) where hidden_size = num_heads * head_dim - # Reshape columns: (out_f, num_heads, head_dim) - W_ = W_.reshape(out_f, num_heads, head_dim) - # Transpose last two dims: (out_f, head_dim, num_heads) - W_ = W_.transpose(1, 2).contiguous() - - if had_matrix is not None: - H = had_matrix.to(device=compute_dev, dtype=torch.float64) - # Apply H on last dim (num_heads): flat @ H.T - flat = W_.reshape(-1, num_heads) - W_ = (flat @ H.T).reshape(out_f, head_dim, num_heads) - elif use_fast_had and fast_hadamard_transform is not None and is_pow2(num_heads): - W_ = fast_hadamard_transform.hadamard_transform(W_, scale=1.0 / math.sqrt(num_heads)) - else: - W_ = matmul_hadU(W_.reshape(-1, num_heads)).reshape(out_f, head_dim, num_heads) - - # Transpose back: (out_f, num_heads, head_dim) → (out_f, hidden_size) - W_ = W_.transpose(1, 2).contiguous().reshape(out_f, num_heads * head_dim) - module.weight.data = W_.to(device=dev, dtype=dtype) - - -# --------------------------------------------------------------------------- -# Grouped (block-diagonal) Hadamard utilities -# --------------------------------------------------------------------------- - - -class OnlineHadamardPostHook(nn.Module): - """Forward hook (post-hook) adapter: wraps a pre-hook-style Hadamard - transform to apply it on the layer's **output** instead of input. - - Used for v_proj per-head Hadamard on the output side when online - rotation is not fused into weights. - """ - - def __init__(self, pre_hook): - super().__init__() - self.pre_hook = pre_hook - - def __call__(self, module, input, output): - result = self.pre_hook(module, (output,)) - if isinstance(result, tuple): - return result[0] - return result - - -class GroupOnlineHadamardHook(nn.Module): - """Pre-forward hook: block-diagonal Hadamard with fixed ``group_size`` on last dim. - - Reshapes ``(*, D)`` → ``(*, D // group_size, group_size)``, applies Hadamard - per group, then reshapes back. Much cheaper than a full-dimension Hadamard. 
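# --- Illustrative sketch (not part of this patch): why folding the Hadamard into the weight is exact ---
# apply_exact_had_to_linear() rewrites W so that the rotated activation reproduces the original output:
# for an orthonormal Q, (x @ Q) @ (W @ Q).T == x @ W.T because Q @ Q.T = I. Quick check with a
# normalized Sylvester Hadamard as Q:
import math
import torch

def sylvester_orthonormal(n: int) -> torch.Tensor:
    H = torch.ones(1, 1, dtype=torch.float64)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / math.sqrt(n)

in_f, out_f = 16, 8
Q = sylvester_orthonormal(in_f)
W = torch.randn(out_f, in_f, dtype=torch.float64)
x = torch.randn(3, in_f, dtype=torch.float64)

y_ref = x @ W.T                    # original nn.Linear (no bias) output
y_rot = (x @ Q) @ (W @ Q).T        # online-rotated activation + offline-rotated weight
assert torch.allclose(y_ref, y_rot)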
- """ - - def __init__(self, group_size, fp32_had=False, use_fast_had=True, had_matrix=None): - super().__init__() - self.group_size = group_size - self.fp32_had = fp32_had - self.use_fast_had = use_fast_had - self.custom_had = had_matrix is not None - - if had_matrix is not None: - self.register_buffer("had_matrix", had_matrix) - self.had_K = None - self.K = None - elif not is_pow2(group_size): - had_K, K = get_hadK(group_size) - if had_K is not None: - self.register_buffer("had_K", had_K) - else: - self.had_K = None - self.K = K - else: - self.had_K = None - self.K = 1 - - def __call__(self, module: nn.Module, args): - x = args[0] if isinstance(args, tuple) else args - x_dtype = x.dtype - init_shape = x.shape - gs = self.group_size - - if self.fp32_had: - x = x.float() - - # Reshape: (*, D) → (*, D//gs, gs) - x = x.reshape(*init_shape[:-1], init_shape[-1] // gs, gs) - - if self.custom_had: - H = self.had_matrix.to(device=x.device, dtype=x.dtype) - flat = x.reshape(-1, gs) - x = (flat @ H.T).reshape(*init_shape[:-1], init_shape[-1] // gs, gs) - elif self.use_fast_had and fast_hadamard_transform is not None and self.K == 1: - x = fast_hadamard_transform.hadamard_transform(x, scale=1.0 / math.sqrt(gs)) - else: - x = x.reshape(-1, gs) - x = matmul_hadU(x) - x = x.reshape(*init_shape[:-1], init_shape[-1] // gs, gs) - - x = x.reshape(init_shape) - - if self.fp32_had: - x = x.to(x_dtype) - - if isinstance(args, tuple): - return (x,) + args[1:] - return x - - -def _apply_grouped_had_to_weight(W, group_size, side="input", use_fast_had=True, had_matrix=None): - """Apply block-diagonal Hadamard to a weight matrix. - - Args: - W: Weight tensor, shape (out_features, in_features). - group_size: Block size for the Hadamard rotation. - side: ``'input'`` rotates columns (in_features dim), - ``'output'`` rotates rows (out_features dim). - use_fast_had: Use fast_hadamard_transform if available. - had_matrix: Optional custom Hadamard matrix of shape ``(gs, gs)`` - to use instead of the built-in Hadamard. - - Returns: - Rotated weight tensor. - """ - gs = group_size - dtype = W.dtype - W = W.double() - - def _had_on_last_dim(X): - """Apply Hadamard on the last dimension (size gs) of X shaped (..., gs).""" - if had_matrix is not None: - H = had_matrix.to(device=X.device, dtype=X.dtype) - # X: (..., gs) → batch matmul with H^T → X @ H^T - flat = X.reshape(-1, gs) - return (flat @ H.T).reshape(X.shape) - if use_fast_had and fast_hadamard_transform is not None and is_pow2(gs): - return fast_hadamard_transform.hadamard_transform(X, scale=1.0 / math.sqrt(gs)) - orig_shape = X.shape - return matmul_hadU(X.reshape(-1, gs)).reshape(orig_shape) - - if side == "input": - out_f, in_f = W.shape - W = W.reshape(out_f, in_f // gs, gs) - W = _had_on_last_dim(W) - W = W.reshape(out_f, in_f) - else: - out_f, in_f = W.shape - Wt = W.t().contiguous() - Wt = Wt.reshape(in_f, out_f // gs, gs) - Wt = _had_on_last_dim(Wt) - W = Wt.reshape(in_f, out_f).t().contiguous() - - return W.to(dtype) - - -def _rotate_linear_grouped(module, group_size, side="input", use_fast_had=True, compute_device=None, had_matrix=None): - """Apply block-diagonal Hadamard rotation to a Linear layer's weight. - - Args: - module: ``nn.Linear`` layer. - group_size: Block size. - side: ``'input'`` or ``'output'``. - use_fast_had: Use fast_hadamard_transform. - compute_device: Device to run computation on. If None, auto-detects GPU. - had_matrix: Optional custom Hadamard matrix of shape ``(gs, gs)``. 
- """ - dtype = module.weight.data.dtype - dev = module.weight.data.device - compute_dev = _resolve_compute_device(compute_device) - W = module.weight.data.to(device=compute_dev, dtype=torch.float64) - W = _apply_grouped_had_to_weight(W, group_size, side=side, use_fast_had=use_fast_had, had_matrix=had_matrix) - module.weight.data = W.to(device=dev, dtype=dtype) - - if side == "output" and module.bias is not None: - bias = module.bias.data.to(device=compute_dev, dtype=torch.float64) - gs = group_size - bias = bias.reshape(-1, gs) - if had_matrix is not None: - H = had_matrix.to(device=compute_dev, dtype=torch.float64) - bias = (bias @ H.T).reshape(-1) - elif use_fast_had and fast_hadamard_transform is not None and is_pow2(gs): - bias = ( - fast_hadamard_transform.hadamard_transform(bias.unsqueeze(0), scale=1.0 / math.sqrt(gs)) - .squeeze(0) - .reshape(-1) - ) - else: - bias = matmul_hadU(bias).reshape(-1) - module.bias.data = bias.to(device=dev, dtype=dtype) - - -def _rotate_embedding_grouped(embedding, group_size, use_fast_had=True, compute_device=None, had_matrix=None): - """Apply block-diagonal Hadamard rotation to an Embedding layer. - - Embedding weight: (vocab, hidden_size) → rotate on hidden_size (columns). - """ - dtype = embedding.weight.data.dtype - dev = embedding.weight.data.device - compute_dev = _resolve_compute_device(compute_device) - W = embedding.weight.data.to(device=compute_dev, dtype=torch.float64) - W = _apply_grouped_had_to_weight(W, group_size, side="input", use_fast_had=use_fast_had, had_matrix=had_matrix) - new_W = W.to(device=dev, dtype=dtype) - del W - embedding.weight.data = new_W - - -def register_online_had_hooks_grouped(model, mapping, group_size, fp32_had=False, use_fast_had=True): - """Register per-group online Hadamard hooks on ``down_proj`` and ``o_proj``. - - In grouped mode: - - **down_proj**: block-diagonal Hadamard on ``intermediate_size`` with ``group_size``. - - **o_proj**: block-diagonal Hadamard on ``hidden_size`` with ``group_size``. - - Args: - model: HuggingFace model with rotated weights. - mapping: RotationMapping. - group_size: Block size for block-diagonal Hadamard. - fp32_had: Compute in fp32. - use_fast_had: Use fast_hadamard_transform. - - Returns: - list of hook handles. - """ - mlp_out_suffix = mapping.mlp_out.split(".")[-1] - attn_o_suffix = mapping.attn_o.split(".")[-1] - - handles = [] - for name, module in model.named_modules(): - if name.endswith(mlp_out_suffix) and isinstance(module, nn.Linear): - hook = GroupOnlineHadamardHook( - group_size=group_size, - fp32_had=fp32_had, - use_fast_had=use_fast_had, - ) - h = module.register_forward_pre_hook(hook) - handles.append(h) - elif name.endswith(attn_o_suffix) and isinstance(module, nn.Linear): - hook = GroupOnlineHadamardHook( - group_size=group_size, - fp32_had=fp32_had, - use_fast_had=use_fast_had, - ) - h = module.register_forward_pre_hook(hook) - handles.append(h) - - return handles +from auto_round.algorithms.transforms.rotation.inplace.hooks import * # noqa: F401, F403 diff --git a/auto_round/experimental/transform/apply.py b/auto_round/experimental/transform/apply.py index aafd4c1b0..e0aa77e13 100644 --- a/auto_round/experimental/transform/apply.py +++ b/auto_round/experimental/transform/apply.py @@ -1,196 +1,13 @@ # # Copyright (C) 2026 Intel Corporation # # SPDX-License-Identifier: Apache-2.0 +"""Backward-compat re-export shim. -import torch -import tqdm +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.apply`. 
+""" -from auto_round.experimental.qmodules.base import QModuleBase -from auto_round.experimental.transform.hadamards import build_hadamard_transform -from auto_round.experimental.transform.rotation_config import RotationConfig -from auto_round.experimental.utils import is_triton_kernel_available, normalize_rotation_config +from auto_round.algorithms.transforms.rotation.apply import ( # noqa: F401 + apply_rotation_transform, +) __all__ = ["apply_rotation_transform"] - - -def apply_rotation_transform( - model: torch.nn.Module, - config: str | dict | RotationConfig | None, - location: str = "weight", - use_tqdm=True, - desc=None, - data_type="mx_fp", -): - """ - Apply a transform configuration to a model. - - Weight and activation transforms are attached as submodules and are - triggered via PyTorch hooks. - - :param model: Model to which the transform configuration will be applied. - :param config: Transform configuration to apply. Supported values are: - * ``str``: A named/preset transform configuration. In this case, - resolved to a concrete quantization/transform configuration. - * ``dict``: A raw configuration mapping that will be normalized - (via :func:`normalize_rotation_config`) and then passed to - :class:`TransformConfig`. - * :class:`TransformConfig`: An existing configuration instance. - This will be used to construct the final configuration after - normalization. - * ``None``: Uses the default behavior of - :func:`_normalize_rotation_config` (for example, inferring a - configuration from ``data_type`` or other project defaults), if - supported. - :param data_type: quantization data type. - :param use_tqdm: If ``True``, wrap the per-module application in a - tqdm progress bar. - :param desc: Optional description string to show in the tqdm progress - bar. If ``None``, a description will be derived from - ``config.transform_type``. - """ - - config = normalize_rotation_config(config, data_type) - if not isinstance(config, RotationConfig): - config = RotationConfig(**config) - - modules_config = [ - (name, module, config) - for name, module in model.named_modules() - if isinstance(module, torch.nn.Linear) or isinstance(module, QModuleBase) - ] - - desc = f"Applying {config.hadamard_type} transforms" if desc is None else desc - for name, module, config in tqdm.tqdm(modules_config, desc=desc, disable=(not use_tqdm)): - if "lm_head" in name: - continue - _apply_to_module(model, module, config, location, data_type) - - # attach config to model for compression/serialization. Use a plain dict so - # that downstream HF `save_pretrained` -> JSON works (RotationConfig is a - # pydantic model and is not directly JSON serializable). - setattr(model, "rotation_config", config.model_dump() if hasattr(config, "model_dump") else config) - hooks = None - - return model, hooks - - -def _apply_to_module( - model: torch.nn.Module, - module: torch.nn.Module, - config: RotationConfig, - location: str = "weight", - data_type: str = "mx_fp", -): - """ - Create transforms and apply them to the module - - :param model: model which module belongs to - :param module: target module to apply transforms to - """ - - # create transform as submodule - hadamard_name = config.hadamard_type - - if location == "input": - - # activation needs transpose - input_hadamard_transform = build_hadamard_transform( - **config.model_dump(), - location="input", - inverse=True, - device="cpu", - precision=module.dtype, # for online activation, the transform dtype maybe bfloat16/float16. 
- ) - - if config.hadamard_type != "random_hadamard": - hadamard_weight = input_hadamard_transform.weight - else: - hadamard_weight = None - - if is_triton_kernel_available(data_type): - from auto_round.experimental.transform.triton.mxfp4 import mxfp4_forward_kernel_wrapper - - def input_hook(self, args): - input = args[0] - # transform(input) - orig_shape = input.shape - orig_dtype = input.dtype - x_flat = input.contiguous().flatten(end_dim=-2) - qdq_input, _ = mxfp4_forward_kernel_wrapper( - x_flat, - ( - hadamard_weight.to(orig_dtype) - if hadamard_weight is not None - else self.hadamard_matrix.T.to(orig_dtype) - ), # this matrix from w_transform, needs transpose - ) - return qdq_input.reshape(orig_shape).to(orig_dtype) - - # for fused transform + quantization kernel - module.pre_dequantized_input = True - module.register_forward_pre_hook(input_hook, prepend=True) - else: - - from auto_round.experimental.transform.utils.matrix import _multihead_matmul - - def input_hook(self, args): - input = args[0] - - ori_shape = input.shape - orig_dtype = input.dtype - - if hadamard_weight is not None: - input = input.view(-1, hadamard_weight.shape[0]) - return ( - (_multihead_matmul(input, hadamard_weight.to(input.device).to(orig_dtype))) - .view(ori_shape) - .to(orig_dtype) - ) - else: - input = input.view(-1, self.hadamard_matrix.shape[0]) - return ( - (_multihead_matmul(input, self.hadamard_matrix.T.to(orig_dtype))).view(ori_shape).to(orig_dtype) - ) - - # for fused transform + quantization kernel - module.pre_dequantized_input = False - module.register_forward_pre_hook(input_hook, prepend=True) - - elif location == "weight": - # eagerly apply transformation to weight - # fuse transform into weight - assert hasattr(module, "weight") - - weight_hadamard_transform = build_hadamard_transform( - **config.model_dump(), - location="weight", - device=module.weight.device, - ) - - # need save random hadamard matrix needed when inference - if config.hadamard_type == "random_hadamard": - # for saving transform weight - from auto_round.experimental.transform.patch_modules import patch_quantlinear - - patch_quantlinear(weight_hadamard_transform) - - # for autoround tuning: weight not tuning - # for rtn: weight transformed before saving - from auto_round.experimental.transform.patch_modules import ( - patch_wrapperlinear_to_apply_transform, - patch_wrapperwalayer_forward_to_apply_transform, - ) - - input_hadamard_transform = build_hadamard_transform( - **config.model_dump(), - location="input", - inverse=True, - device=module.weight.device, - precision=module.weight.dtype, # for online activation, the transform dtype maybe bfloat16/float16. - ) - - patch_wrapperlinear_to_apply_transform(weight_hadamard_transform, input_hadamard_transform) - patch_wrapperwalayer_forward_to_apply_transform(input_hadamard_transform) - - else: - # TODO: apply transform to output/q/k - raise NotImplementedError() diff --git a/auto_round/experimental/transform/hadamards.py b/auto_round/experimental/transform/hadamards.py index 42c47818d..83efa91b6 100644 --- a/auto_round/experimental/transform/hadamards.py +++ b/auto_round/experimental/transform/hadamards.py @@ -1,143 +1,24 @@ -# Copyright (c) 2026 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
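# --- Illustrative sketch (not part of this patch): rotating activations with a forward pre-hook ---
# The input_hook variants above are attached with register_forward_pre_hook and replace the incoming
# activation before the wrapped Linear runs. The bare mechanism looks like this:
import torch

lin = torch.nn.Linear(8, 4, bias=False)
Q = torch.linalg.qr(torch.randn(8, 8)).Q          # stand-in transform weight

def input_hook(module, args):
    x = args[0]
    return (x @ Q.to(x.dtype),) + args[1:]        # returning a tuple replaces the positional args

handle = lin.register_forward_pre_hook(input_hook)
x = torch.randn(2, 8)
assert torch.allclose(lin(x), (x @ Q) @ lin.weight.T, atol=1e-5)
handle.remove()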
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import math -from typing import Any, Callable, Dict - -import torch -import torch.nn as nn - -from auto_round.experimental.transform.utils.hadamard import deterministic_hadamard_matrix, random_hadamard_matrix -from auto_round.experimental.transform.utils.matrix import apply_transform_weight - - -def filter_kwarg_dict(fn_or_method: Callable, kwarg_dict: Dict[str, Any]) -> Dict[str, Any]: - fn_or_method_keys = inspect.signature(fn_or_method).parameters.keys() - return {k: v for k, v in kwarg_dict.items() if k in fn_or_method_keys} - - -class HadamardTransform(nn.Module): - - def __init__( - self, - block_size: int = 32, - device: torch.device = None, - precision: torch.dtype = torch.float32, - location: str = "weight", - module_type: type[torch.nn.Module] = torch.nn.Linear, - inverse: bool = False, - ): - """Initialize a Hadamard transform module. - - Args: - block_size: Size of each Hadamard block. The input tensor is reshaped - to ``(-1, block_size)`` before applying the transform. - device: Device on which to create the Hadamard matrix. - precision: Data type used for the Hadamard matrix weights, using float32 as default. - location: Target location used by ``apply_transform_weight`` when - applying the transform. - module_type: Module type associated with the transform application, - typically ``torch.nn.Linear``. - inverse: Whether to build the inverse form of the transform. 
- """ - - super().__init__() - self.size = block_size - self.scale = 1 / math.sqrt(self.size) - self.location = location - self.module_type = module_type - self.inverse = inverse - self.weight = self._create_weight(self.size, device, precision) - - def _create_weight( - self, - size: int, - device: torch.device = None, - precision: torch.dtype = torch.float32, - ) -> torch.nn.Parameter: - data = deterministic_hadamard_matrix(size, precision, device) * self.scale - # TODO: implement SpinQuant, which rotation matrix is learnable - return nn.Parameter(data, requires_grad=False) - - def forward(self, x: torch.Tensor): - # Hadamard transform is it own inverse - ori_shape = x.shape - x = x.view(-1, self.size) - return ( - ( - apply_transform_weight( - self.weight.to(x.device), - x.to(dtype=self.weight.dtype), - self.location, - self.module_type, - ) - ) - .to(x.dtype) - .view(ori_shape) - ) - - -class RandomHadamardTransform(HadamardTransform): - def __init__( - self, - block_size: int = 32, - device: torch.device = None, - precision: torch.dtype = None, - location: str = "weight", - module_type: type[torch.nn.Module] = torch.nn.Linear, - inverse: bool = False, - seed: int | None = None, - generator: torch.Generator | None = None, - ): - if generator is not None: - self.generator = generator - else: - self.generator = torch.Generator() - if seed is not None: - self.generator.manual_seed(seed) - - super().__init__( - block_size=block_size, - device=device, - precision=precision, - location=location, - module_type=module_type, - inverse=inverse, - ) - - def _create_weight( - self, - size: int, - device: torch.device = None, - precision: torch.dtype = None, - ) -> torch.nn.Parameter: - data = random_hadamard_matrix(size, precision, device, self.generator) * self.scale - # activation needs transpose - if self.inverse: - data = data.T - # data = deterministic_hadamard_matrix(size, precision, device) * self.scale - # TODO: implement SpinQuant, which rotation matrix is learnable - return nn.Parameter(data, requires_grad=False) - - -HADAMARDS = { - "hadamard": HadamardTransform, - "random_hadamard": RandomHadamardTransform, -} - - -def build_hadamard_transform(hadamard_type: str, **hadamard_kwargs): - hadamard = HADAMARDS[hadamard_type] - return hadamard(**filter_kwarg_dict(hadamard.__init__, hadamard_kwargs)) +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 +"""Backward-compat re-export shim. + +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.transforms`. 
+""" + +from auto_round.algorithms.transforms.rotation.transforms import ( + HADAMARDS, + HadamardTransform, + RandomHadamardTransform, +) +from auto_round.algorithms.transforms.rotation.transforms import _filter_kwargs as filter_kwarg_dict # noqa: F401 +from auto_round.algorithms.transforms.rotation.transforms import ( + build_hadamard_transform, +) + +__all__ = [ + "HADAMARDS", + "HadamardTransform", + "RandomHadamardTransform", + "build_hadamard_transform", +] diff --git a/auto_round/experimental/transform/patch_modules.py b/auto_round/experimental/transform/patch_modules.py index e0f9adbae..980c28588 100644 --- a/auto_round/experimental/transform/patch_modules.py +++ b/auto_round/experimental/transform/patch_modules.py @@ -1,172 +1,19 @@ # # Copyright (C) 2026 Intel Corporation # # SPDX-License-Identifier: Apache-2.0 - -import torch -import transformers - -from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear, pack_fp4_to_uint8 -from auto_round.wrapper import WrapperLinear, WrapperWALayer - - -def patch_wrapperlinear_to_apply_transform(w_transform, inp_transform): - """ - Globally monkey-patch WrapperLinear._qdq_weight and WrapperLinear._qdq_act so that it applies - a weight and activation transform before quantization. - - e.g. by apply_transform() before wrapper_block(). - """ - - if getattr(WrapperLinear, "_hadamard_patched", False): - return - - orig_qdq_weight = WrapperLinear._qdq_weight - - def _qdq_weight_patched(self, value, min_scale, max_scale): - """ - # If no transform attached, fall back to original behavior - if not hasattr(self.orig_layer, transform_attr): - return orig_qdq_weight(self, value, min_scale, max_scale) - """ - - if self.orig_layer.bits >= 16: - # keep original behavior for >=16bit to avoid changing semantics unexpectedly - return orig_qdq_weight(self, value, min_scale, max_scale) - - if getattr(self, "applied_weight_hadamard", None) is None: - with torch.no_grad(): - weight = self.orig_layer.weight - if weight.device.type == "meta": - weight = self.orig_layer.get_weight().to(self.device) - - is_conv1d = type(self.orig_layer) == transformers.pytorch_utils.Conv1D - if is_conv1d: - weight = weight.t().continuous() - new_weight = w_transform(weight).to(self.device) - if is_conv1d: - new_weight = weight.t().continuous() - self.orig_layer.weight.data.copy_(new_weight) - self.applied_weight_hadamard = True - - return orig_qdq_weight(self, value, min_scale, max_scale) - - orig_qdq_act = WrapperLinear._qdq_act - - def _qdq_act_patched(self, x, act_min_scale, act_max_scale, act_max=None): - - x = inp_transform(x) - - return orig_qdq_act(self, x, act_min_scale, act_max_scale, act_max) - - WrapperLinear._qdq_weight = _qdq_weight_patched - WrapperLinear._qdq_act = _qdq_act_patched - WrapperLinear._hadamard_patched = True - - -def patch_wrapperwalayer_forward_to_apply_transform(inp_transform): - """ - Globally monkey-patch WrapperWALayer.forward so that it applies - a activation transform before quantization. - - e.g. by apply_transform() before wrapper_block(). 
- """ - - if getattr(WrapperWALayer, "_hadamard_forward_patched", False): - return - - orig_forward = WrapperWALayer.forward - - def _forward_patched(self, x): - """ - # If no transform attached, fall back to original behavior - if not hasattr(self.orig_layer, transform_attr): - return orig_forward(self, x) - """ - - act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None - - # transform = getattr(self.orig_layer, transform_attr) - x = inp_transform(x) - - x, _, _ = self.orig_layer.act_quant_func( - x, - bits=self.orig_layer.act_bits, - group_size=self.orig_layer.act_group_size, - scale_dtype=self.orig_layer.scale_dtype, - q_scale_thresh=self.orig_layer.q_scale_thresh, - data_type=self.orig_layer.act_data_type, - tensor_max=act_max, - ) - return self.orig_layer.forward(x) - - WrapperWALayer.forward = _forward_patched - WrapperWALayer._hadamard_forward_patched = True - - -def patch_quantlinear(w_transform): - """ """ - - if getattr(QuantLinear, "_pack_patched", False): - return - - from auto_round.data_type.nvfp import cast_to_fp4, get_reciprocal - from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad - from auto_round.utils import get_packing_device - - E8M0_EXPONENT_BIAS = 127 - E8M0_EXPONENT_NAN_VAL = 255 - - def _pack_patched( - self, linear, scales, zeros=None, g_idx=None, global_scale=None, input_global_scale=None, device=None - ): - device = get_packing_device(device) - if getattr(linear, "bias", None) is not None: - self.bias = linear.bias.detach().to(torch.float16) - - W = linear.weight.data.detach().to(device) - if type(linear) == torch.nn.Conv2d: - W = W.flatten(1) - if type(linear) == transformers.pytorch_utils.Conv1D: - W = W.t() - - tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(W, self.group_size) - scales = scales.to(device) - if self.is_nv: - assert global_scale is not None and global_scale.numel() == 1 - global_scale = global_scale.reshape([1]) - global_scale = global_scale.to(device) - scaled_tensor = tensor.to(global_scale.dtype) * get_reciprocal( - scales.reshape(tensor.shape[0], -1) * get_reciprocal(global_scale) - ) - scaled_tensor.clamp_(-6.0, 6.0) - scaled_tensor = cast_to_fp4(scaled_tensor) - else: - scaled_tensor = tensor / (2 ** scales.reshape(tensor.shape[0], -1)) - scaled_tensor = revert_tensor_by_pad(scaled_tensor, orig_shape=orig_shape, pad_len=pad_len) - if self.is_mx: - final_scale = (scales + E8M0_EXPONENT_BIAS).clamp(0, E8M0_EXPONENT_NAN_VAL).to(torch.uint8) - else: - final_scale = scales.to(torch.float8_e4m3fn) - - self.weight_scale = final_scale - # self.weight = get_compressed_weight(scaled_tensor, self.bits, self.data_type) ## TODO - if self.bits == 8: - compress_dtype = torch.float8_e4m3fn - self.weight = scaled_tensor.to(compress_dtype) - - else: - compress_dtype = torch.uint8 - self.weight_packed = pack_fp4_to_uint8(scaled_tensor) - - if global_scale is not None: - self.weight_global_scale = global_scale.to(torch.float32).to(device) - - if input_global_scale is not None: - # TODO: the shape of `input_global_scale` is [] in some cases — need to investigate why. - self.input_global_scale = input_global_scale.to(torch.float32).to(device).reshape([1]) - - # add transform weight - self.register_buffer("hadamard_matrix", w_transform.weight.to(device)) - return - - QuantLinear.pack = _pack_patched - QuantLinear._pack_patched = True +"""Backward-compat re-export shim. + +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.patch`. 
+""" + +from auto_round.algorithms.transforms.rotation.patch import ( # noqa: F401 + patch_quantlinear, + patch_wrapperlinear_to_apply_transform, + patch_wrapperwalayer_forward_to_apply_transform, +) + +__all__ = [ + "patch_quantlinear", + "patch_wrapperlinear_to_apply_transform", + "patch_wrapperwalayer_forward_to_apply_transform", +] diff --git a/auto_round/experimental/transform/rotation_config.py b/auto_round/experimental/transform/rotation_config.py index dfcfb5d45..fad2ede62 100644 --- a/auto_round/experimental/transform/rotation_config.py +++ b/auto_round/experimental/transform/rotation_config.py @@ -1,61 +1,11 @@ # # Copyright (C) 2026 Intel Corporation # # SPDX-License-Identifier: Apache-2.0 +"""Backward-compat re-export shim. -from typing import Optional +The canonical ``RotationConfig`` schema now lives in +:mod:`auto_round.algorithms.transforms.rotation.config`. +""" -from pydantic import BaseModel, Field, field_validator +from auto_round.algorithms.transforms.rotation.config import RotationConfig # noqa: F401 __all__ = ["RotationConfig"] - - -class RotationConfig(BaseModel): - """ - Unified configuration for Hadamard rotation/transform applied to a model. - - Two implementation paths are supported: - - * ``backend="inplace"`` -> ``auto_round.experimental.rotation_inplace`` - QuaRot-style residual-stream / per-layer rotation. Supports any - weight/activation dtype (incl. INT4/INT8/FPx). Can optionally fuse - the online Hadamard into weights (``fuse_online_to_weight=True``). - * ``backend="transform"`` -> ``auto_round.experimental.transform`` - Per-Linear weight + activation Hadamard with a fused triton kernel. - **Only supports MXFP4 / NVFP4** and **cannot fuse online to weight.** - * ``backend="auto"`` (default) - - If ``fuse_online_to_weight=True`` -> inplace (fused). - - Else if ``data_type`` is MX-FP / NV-FP -> transform. - - Otherwise -> inplace (unfused). - - Notes: - * ``block_size`` is the group/block size for grouped Hadamard. - For ``backend="inplace"`` it is forwarded as ``group_size`` (``None`` - / ``-1`` means full-dimension Hadamard). - """ - - # ---- shared ---- - backend: str = Field(default="auto") - block_size: Optional[int] = Field(default=None) - hadamard_type: str = Field(default="hadamard") - - # ---- inplace-only ---- - fuse_online_to_weight: Optional[bool] = Field(default=None) - allow_online_rotation: bool = Field(default=True) - - # for random hadamard transform (transform path) - random_seed: bool = Field(default=False, exclude=True) - - @field_validator("backend") - @classmethod - def validate_backend(cls, v: str) -> str: - allowed = {"auto", "inplace", "transform"} - if v not in allowed: - raise ValueError(f"Unsupported backend: {v}. Supported values: {sorted(allowed)}") - return v - - @field_validator("hadamard_type") - @classmethod - def validate_hadamard_type(cls, v: str) -> str: - allowed = {"hadamard", "random_hadamard", "quarot_hadamard"} - if v not in allowed: - raise ValueError(f"Unsupported hadamard_type: {v}. Supported values: {sorted(allowed)}") - return v diff --git a/auto_round/experimental/transform/triton/mxfp4.py b/auto_round/experimental/transform/triton/mxfp4.py index 8028c167b..3cc26a7d0 100644 --- a/auto_round/experimental/transform/triton/mxfp4.py +++ b/auto_round/experimental/transform/triton/mxfp4.py @@ -1,197 +1,14 @@ -# Copyright (c) 2026 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 +"""Backward-compat re-export shim. -# Refer code here: -# https://github.com/IST-DASLab/FP-Quant/blob/master/inference_lib/src/fp_quant/module/triton/mxfp4.py +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.utils.triton.mxfp4`. +""" -import torch -import triton # pylint: disable=E0401 -import triton.language as tl # pylint: disable=E0401 - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 32 * 32}), - triton.Config({"BLOCK_SIZE": 64 * 32}), - triton.Config({"BLOCK_SIZE": 128 * 32}), - triton.Config({"BLOCK_SIZE": 256 * 32}), - triton.Config({"BLOCK_SIZE": 512 * 32}), - ], - key=[], +from auto_round.algorithms.transforms.rotation.utils.triton.mxfp4 import ( # noqa: F401 + mxfp4_forward_kernel, + mxfp4_forward_kernel_wrapper, ) -@triton.jit -def mxfp4_forward_kernel( - x_ptr, - hadamard_matrix_ptr, - output_ptr, - clip_mask_ptr, - n_elements: tl.constexpr, - hadamard_dim: tl.constexpr, - group_size: tl.constexpr, - gaussian_scale: tl.constexpr, - quest: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim) - hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim) - - # load x - pid = tl.program_id(0) - start_idx = pid * BLOCK_SIZE - offsets = start_idx + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x_flat = tl.load(x_ptr + offsets, mask=mask) - - # hadamard transform - x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim)) - x_had = tl.dot(x, hadamard_matrix) - - # group - x_had_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size)) - - # scale - # quest=True: per-group Gaussian-based scale = gaussian_scale * std - # quest=False: per-group max-abs-based scale, adjusted to FP4 range - if quest: - mean_squared = tl.sum(x_had_grouped * x_had_grouped, axis=-1, keep_dims=True) / group_size - mean = tl.sum(x_had_grouped, axis=-1, keep_dims=True) / group_size - std = tl.sqrt(mean_squared - mean * mean) - scales = gaussian_scale * std + 1e-8 - shared_exps = tl.exp2(tl.floor(tl.log2(scales))) - x_had_scaled = x_had_grouped / shared_exps - else: - scales = tl.max(tl.abs(x_had_grouped), axis=-1, keep_dims=True) - shared_exps = tl.exp2(tl.floor(tl.log2(scales)) - 2) / (3 / 4) - x_had_scaled = x_had_grouped / shared_exps - - # quantize - # Map abs(x) to FP4 levels {0, 0.5, 1, 1.5, 2, 3, 4, 6} - x_had_scaled_abs = tl.abs(x_had_scaled) - x_had_scaled_sign = tl.where( - x_had_scaled > 0, - 1, - -1, - ) - - x_fp4 = ( - tl.where( - x_had_scaled_abs > 5, - 6, - tl.where( - x_had_scaled_abs > 3.5, - 4, - tl.where( - x_had_scaled_abs > 2.5, - 3, - tl.where( - x_had_scaled_abs > 1.75, - 2, - tl.where( - x_had_scaled_abs > 1.25, - 1.5, - tl.where( - x_had_scaled_abs > 0.75, - 1, - tl.where( - x_had_scaled_abs > 0.25, - 0.5, - 0, - ), - ), - ), - ), - ), - ), - ) - * x_had_scaled_sign - ) - if clip_mask_ptr is not None: - tl.store( - clip_mask_ptr + offsets, - tl.reshape(x_had_scaled_abs < 6, (BLOCK_SIZE,)), - mask=mask, - ) - - # 
dequantize - x_dequantized = x_fp4 * shared_exps - - # Reshape back to flat form for storage - x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,)) - - # store - tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask) - - -@torch.compiler.disable() -def mxfp4_forward_kernel_wrapper( - x, - hadamard_matrix, - return_clip_mask=False, - quest=False, - gaussian_scale=3 / 4, -): - """ - Refer code here: - https://github.com/IST-DASLab/FP-Quant/blob/master/inference_lib/src/fp_quant/module/triton/mxfp4.py - Apply Hadamard transform + group-wise FP4 quantize/dequantize on x. - - Note: - The output is still in the Hadamard-transformed space (no inverse Hadamard is applied). - """ - # Pick a device — we require CUDA - device = x.device - if device.type != "cuda": - raise RuntimeError( - f"mxfp4_forward_kernel_wrapper requires a CUDA tensor for 'x', " - f"but got device '{device.type}'. Please move inputs to CUDA before calling." - ) - - # Ensure hadamard_matrix is on the same CUDA device - if hadamard_matrix.device != device: - hadamard_matrix = hadamard_matrix.to(device) - - dtype = hadamard_matrix.dtype - - if x.dtype != dtype: - x = x.to(dtype) - - # Make sure inputs are contiguous - x = x.contiguous() - hadamard_matrix = hadamard_matrix.contiguous() - - # Create output tensors on CUDA - output = torch.empty_like(x, device=device) - if return_clip_mask: - clip_mask = torch.empty_like(x, dtype=torch.bool, device=device).contiguous() - else: - clip_mask = None - - # Get total number of elements and calculate grid for launching the kernel - n_elements = x.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - - # Launch kernel – no need for `with torch.device(...)` - mxfp4_forward_kernel[grid]( - x_ptr=x, - hadamard_matrix_ptr=hadamard_matrix, - output_ptr=output, - clip_mask_ptr=clip_mask, - n_elements=n_elements, - hadamard_dim=hadamard_matrix.shape[-1], - group_size=32, - gaussian_scale=gaussian_scale, - quest=quest, - ) - return output, clip_mask +__all__ = ["mxfp4_forward_kernel", "mxfp4_forward_kernel_wrapper"] diff --git a/auto_round/experimental/transform/utils/hadamard.py b/auto_round/experimental/transform/utils/hadamard.py index 320ae1832..8bed8a7c1 100644 --- a/auto_round/experimental/transform/utils/hadamard.py +++ b/auto_round/experimental/transform/utils/hadamard.py @@ -1,151 +1,18 @@ # # Copyright (C) 2026 Intel Corporation # # SPDX-License-Identifier: Apache-2.0 +"""Backward-compat re-export shim. -import math -from pathlib import Path - -import torch -from safetensors import safe_open - -REPO_PATH = Path(__file__).parent / "hadamards.safetensors" +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.utils.math`. +""" +from auto_round.algorithms.transforms.rotation.utils.math import _HADAMARD_MATRICES_PATH as REPO_PATH # noqa: F401 +from auto_round.algorithms.transforms.rotation.utils.math import ( + _fetch_hadamard_divisor, + _matmul_hadU, + deterministic_hadamard_matrix, + is_pow2, + random_hadamard_matrix, +) __all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"] - - -# note that hadamard matrix multiplication reuses the code from -# https://github.com/vllm-project/compressed-tensors/blob/main/src/compressed_tensors/transform/utils/hadamard.py - - -def deterministic_hadamard_matrix( - size: int, - dtype: torch.dtype = torch.bfloat16, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - """ - Construct an n-by-n Hadamard matrix, using Sylvester's construction. 
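# --- Illustrative sketch (not part of this patch): a plain-PyTorch reference for the FP4 rounding above ---
# The triton kernel maps |x| onto the FP4 value set {0, 0.5, 1, 1.5, 2, 3, 4, 6} after dividing by a
# per-group shared exponent. A small PyTorch rendition of the non-quest branch, handy for checking the
# kernel on toy inputs (not the library API):
import torch

FP4_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
THRESHOLDS = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0])   # the kernel's cut points

def mxfp4_qdq_reference(x: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    g = x.reshape(-1, group_size)
    scales = g.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)     # guard all-zero groups
    shared_exps = torch.exp2(torch.floor(torch.log2(scales)) - 2) / 0.75
    scaled = g / shared_exps
    idx = torch.bucketize(scaled.abs(), THRESHOLDS)    # same strict '>' thresholds as the kernel
    q = FP4_LEVELS[idx] * scaled.sign()
    return (q * shared_exps).reshape(x.shape)

x = torch.randn(4, 64)
print((mxfp4_qdq_reference(x) - x).abs().max())        # inspect the quantization error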
- `n` must be a power of 2. - - Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py # noqa: E501 - - :param size: order of the matrix, must be a power of 2 - :param dtype: data type of matrix - :param device: device to construct matrix on - :return: hadamard matrix of size `size` - """ - if size <= 0: - raise ValueError("Cannot construct deterministic hadamard of size <= 0") - - log2 = int(math.log2(size)) - if size != 2**log2: - raise ValueError("Cannot construct deterministic hadamard of size != 2^n") - - H = torch.tensor([[1]], dtype=dtype, device=device) - - # Sylvester's construction - for _ in range(log2): - H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H)))) - - return H - - -def random_hadamard_matrix( - size: int, - dtype: torch.dtype = torch.bfloat16, - device: torch.device = torch.device("cpu"), - gen: torch.Generator | None = None, -) -> torch.Tensor: - """ - Produces a randomly generated Hadamard matrix. Differs from - `deterministic_hadamard_matrix` in that this function supports non powers of 2 - and randomization using a seeded generator - - Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py # noqa: E501 - Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices - http://www.neilsloane.com/hadamard/ # noqa: E501 - - :param size: The dimension of the hamadard matrix - :param dtype: data type of matrix - :param device: device to construct matrix on - :param gen: Optional generator random values - :return: randomly generated hadamard matrix - """ - Q = torch.randint(low=0, high=2, size=(size,), generator=gen, device="cpu") # cpu - Q = Q.to(device=device, dtype=dtype) - Q = Q * 2 - 1 - Q = torch.diag(Q) - return _matmul_hadU(Q) - - -def is_pow2(n: int) -> bool: - """ - Check if a number is a power of 2 - - :param n: number to check - :return: True iff `n` is a power of 2 - """ - return n > 0 and (n & (n - 1) == 0) - - -def _fetch_hadamard_divisor( - n: int, - dtype: torch.dtype, - device: torch.device = torch.device("cpu"), - file_path: str = REPO_PATH, -) -> torch.Tensor | None: - """ - Fetch a known hadamard matrix from the given file path. The returned matrix will - be of of size `k` such that `n / k` is a power of two. Return None if no such - matrix exists. - - Note: This function reopens the safetensors file every time it is called. 
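# --- Illustrative sketch (not part of this patch): picking a Hadamard "divisor" for non-power-of-2 sizes ---
# _fetch_hadamard_divisor() looks for a known base matrix of size k with n % k == 0 and n // k a power
# of two, so the full transform factors into the known k x k matrix plus power-of-two butterflies.
# The selection logic alone, over a hypothetical set of available sizes:
def is_pow2(n: int) -> bool:
    return n > 0 and (n & (n - 1) == 0)

def pick_divisor(n: int, known_sizes) -> int | None:
    for k in sorted(known_sizes, reverse=True):     # prefer the largest usable base matrix
        if n % k == 0 and is_pow2(n // k):
            return k
    return None

assert pick_divisor(11008, {12, 20, 28, 40, 172}) == 172   # 11008 = 172 * 64
assert pick_divisor(4096, {12, 20, 28, 40, 172}) is None   # pure powers of two need no base matrix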
- This is technically inefficient, but a very small runtime cost and simpler - than forcing callers to manage the file open context - - :param n: size of known hadamard matrix - :param dtype: data type to move fetched hadamard to - :param device: device to move fetched hadamard to - :return: a known hadamard matrix of size `n` if one exists, else None - """ - open_device = torch.device("cpu") if device.type == "meta" else device - with safe_open(file_path, framework="pt", device=str(open_device)) as file: - divisors = sorted((int(key) for key in file.keys()), reverse=True) - for divisor in divisors: - if n % divisor == 0 and is_pow2(n // divisor): - return file.get_tensor(str(divisor)).to(dtype=dtype, device=device) - - return None - - -def _matmul_hadU(X: torch.Tensor) -> torch.Tensor: - size = X.size(0) - dtype = X.dtype - device = X.device - - # Check if we have the determined hadamard matrix - hadK = _fetch_hadamard_divisor(size, dtype, device=device) - if hadK is None: - raise ValueError(f"Cannot construct random hadamard matrix of size {size}") - K = hadK.size(0) - - # Reshape diag matrix with randomized -1/+1 - input = X.clone().view(-1, size, 1) - output = input.clone() - while input.shape[1] > K: - input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2]) - output = output.view(input.shape) - output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :] - output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :] - output = output.view(input.shape[0], input.shape[1], -1) - input, output = (output, input) - assert input.shape[1] == K - del output - - # Do not explicitly repeat - OOM - # input = torch.bmm( - # hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input) - # Use bcast instead - input = hadK.view(1, K, K).to(input) @ input - - # normalize - return input.view(X.shape) diff --git a/auto_round/experimental/transform/utils/matrix.py b/auto_round/experimental/transform/utils/matrix.py index 46d684e80..184d91055 100644 --- a/auto_round/experimental/transform/utils/matrix.py +++ b/auto_round/experimental/transform/utils/matrix.py @@ -1,98 +1,17 @@ # # Copyright (C) 2026 Intel Corporation # # SPDX-License-Identifier: Apache-2.0 +"""Backward-compat re-export shim. -import torch +The canonical implementation now lives in +:mod:`auto_round.algorithms.transforms.rotation.utils.matrix`. +""" -__all__ = ["apply_transform_weight"] - -# note that apply_transform_weight reuses some code from -# https://github.com/vllm-project/compressed-tensors/blob/main/src/compressed_tensors/transform/utils/matrix.py - - -def apply_transform_weight( - transform_weight: torch.Tensor, - value: torch.Tensor, - location: str, - module_type: type[torch.nn.Module], -) -> torch.Tensor: - """ - Using the transform location, apply the transform_weight to the - given value wrt linear weights. 
For more info on input and output transforms, - see `TransformLocation` - - The following explains how weights should be applied to values according to location - - let x be input activation - W be weight, - yh, xh, Wh be transformed output, input, weight - - note that - y = (x W.T) // torch.nn.Linear - - Choose values for yh, xh, and Wh which incorporate matrix transforms - - let V, Vi be transform matrices on input side - U, Ui be transform matrices on output side - - pick xh = (x V) - Wh = (U.T W Vi.T) - yh = (y U) +from auto_round.algorithms.transforms.rotation.utils.matrix import ( # noqa: F401 + apply_transform_weight, + multihead_matmul, +) - The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh +# Old private name kept for backward compatibility. +_multihead_matmul = multihead_matmul - (xh) (Wh).T = (x V) (U.T W Vi.T).T - = (x V) (Vi W.T U) // transpose matrix product identity - = (x W.T) U - = y U - = yh - - :param transform_weight: transform weight to apply - :param value: value to apply transform_weight to - :param location: determines how weight should be applied - :param model_type: result of type(module), passed in to determine application of - weight transform - :return: value after transform_weight has been applied - """ - - if location == "input": - return _multihead_matmul(value, transform_weight) - - if module_type == torch.nn.Linear: - return _multihead_matmul(value, transform_weight.T) - - -def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - """ - Performs A @ B for last two dims of two matrices A and B that possibly - have different shapes, as is the case in multi-headed dimension. If - shapes are different, this is equivalent to converting the last two dims - of the smaller matrix into a block-diagonal matrix with the same shape as - the last two dims of the larger matrix. - - E.g. 
if A is half the size of B, this function will perform - [[A ] @ B - [ A]] - - If B is a third of the size of A, this function will perform - A @ [[B ] - [ B ] - [ B]] - - This function will error out if the shapes are not evenly divisible - - :param A: left-hand tensor - :param B: right-hand tensor - :return: result - """ - if A.shape[-1] > B.shape[-2]: - head_dim = B.shape[-2] - num_heads = A.shape[-1] // head_dim - A = A.unflatten(-1, (num_heads, head_dim)) - return (A @ B).flatten(-2, -1) - elif A.shape[-1] < B.shape[-2]: - head_dim = A.shape[-1] - num_heads = B.shape[-2] // head_dim - B = B.unflatten(-2, (num_heads, head_dim)) - return (A @ B).flatten(-3, -2) - else: - return A @ B +__all__ = ["apply_transform_weight"] diff --git a/auto_round/experimental/utils.py b/auto_round/experimental/utils.py index ea1dade78..80e830bed 100644 --- a/auto_round/experimental/utils.py +++ b/auto_round/experimental/utils.py @@ -138,84 +138,21 @@ def is_triton_kernel_available(data_type: str) -> bool: def dump_group_size_to_rotation_config(rotation_config: str | dict | RotationConfig, group_size: int): - rotation_dict = to_dict_rotation_config(rotation_config) - if rotation_dict.get("block_size", None) is None: - rotation_dict["block_size"] = group_size - return rotation_dict + from auto_round.algorithms.transforms.rotation.config import dump_group_size_to_rotation_config as _impl + + return _impl(rotation_config, group_size) def to_dict_rotation_config(rotation_config: str | dict | RotationConfig): - if isinstance(rotation_config, str): - key = rotation_config.strip() - if not key: - return {} + from auto_round.algorithms.transforms.rotation.config import to_dict_rotation_config as _impl - if key == "default": - cfg_dict = {"hadamard_type": "hadamard"} - else: - cfg_dict = {"hadamard_type": key} - elif isinstance(rotation_config, RotationConfig): - cfg_dict = rotation_config.model_dump() - else: - cfg_dict = dict(rotation_config) - return cfg_dict + return _impl(rotation_config) def normalize_rotation_config(rotation_config: str | dict | RotationConfig | None, data_type: str) -> dict[str, Any]: - """ - Normalize and validate `rotation_config`. - - Supported input types: - - None -> {} - - dict -> validated via RotationConfig - - RotationConfig -> validated & converted to dict - - str -> shorthand for `hadamard_type` in HADAMARDS keys - - Additional behavior: - - If block_size is not set by user: - - mx_fp -> default block_size to 32 - - nv_fp -> default block_size to 16 - - other data types -> emit a warning - - If block_size is set but does not match the recommended value: - - mx_fp expects 32 - - nv_fp expects 16 - - emit a warning - """ + from auto_round.algorithms.transforms.rotation.config import normalize_rotation_config as _impl - def _apply_data_type_block_size(cfg_dict: dict[str, Any], block_size_explicitly_set: bool) -> dict[str, Any]: - block_size = cfg_dict.get("block_size") - - if not block_size_explicitly_set or block_size is None: - if is_mx_fp(data_type): - cfg_dict["block_size"] = 32 - logger.warning("block_size is not set for data_type 'mx_fp'; defaulting to 32.") - elif is_nv_fp(data_type): - cfg_dict["block_size"] = 16 - logger.warning("block_size is not set for data_type 'nv_fp'; defaulting to 16.") - else: - logger.warning( - f"block_size is not set and cannot be inferred for data_type {data_type!r}; " - "please set block_size explicitly in rotation_config if needed." 
- ) - else: - if is_mx_fp(data_type) and block_size != 32: - logger.warning(f"data_type is 'mx_fp' but block_size={block_size}; recommended value is 32.") - elif is_nv_fp(data_type) and block_size != 16: - logger.warning(f"data_type is 'nv_fp' but block_size={block_size}; recommended value is 16.") - - return cfg_dict - - # 1) None -> {} - if rotation_config is None: - return {} - - rotation_dict = to_dict_rotation_config(rotation_config) - block_size_explicitly_set = "block_size" in rotation_dict - cfg_dict = _apply_data_type_block_size(rotation_dict, block_size_explicitly_set) - try: - return RotationConfig.model_validate(cfg_dict).model_dump() - except Exception as e: - raise ValueError(f"Invalid RotationConfig: {e}") from e + return _impl(rotation_config, data_type) def check_supported_schemes(scheme: str): diff --git a/auto_round/export/export_to_autogptq/export.py b/auto_round/export/export_to_autogptq/export.py index 5594b5979..c46b8c947 100644 --- a/auto_round/export/export_to_autogptq/export.py +++ b/auto_round/export/export_to_autogptq/export.py @@ -205,7 +205,7 @@ def save_quantized_as_autogptq( # --- 1️⃣ Extract inputs & configs --- quantization_config = serialization_dict - quant_block_list = serialization_dict.get("quant_block_list", get_block_names(model)) + quant_block_list = serialization_dict.get("quant_block_list") or get_block_names(model) processor = kwargs.get("processor") image_processor = kwargs.get("image_processor") safe_serialization = kwargs.get("safe_serialization", True) @@ -249,7 +249,7 @@ def save_quantized_as_autogptq( continue # Handle block layers if in_blocks or (block_name_to_quantize and check_start_with_block_name(layer_name, block_name_to_quantize)): - neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys}) + neq_keys = check_neq_config(cfg, **{k: quantization_config.get(k) for k in scheme_keys}) if neq_keys: if matches_any_regex(layer_name, regex_config): continue diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 60613b7e8..94a439975 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import copy import functools import inspect @@ -285,7 +284,7 @@ def save_quantized_as_autoround( elif cfg["in_blocks"] or ( block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize) ): - neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys}) + neq_keys = check_neq_config(cfg, **{k: quantization_config.get(k) for k in scheme_keys}) if len(neq_keys) > 0: extra_config[layer_name] = {} for key in neq_keys: @@ -296,7 +295,7 @@ def save_quantized_as_autoround( if regex_config is not None: for name, cfg in regex_config.items(): regex_name = to_standard_regex(name) - neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys}) + neq_keys = check_neq_config(cfg, **{k: quantization_config.get(k) for k in scheme_keys}) if len(neq_keys) > 0: extra_config[regex_name] = {} for key in neq_keys: diff --git a/auto_round/export/export_to_autoround/export_to_fp8.py b/auto_round/export/export_to_autoround/export_to_fp8.py index 90c228583..200e10a94 100644 --- a/auto_round/export/export_to_autoround/export_to_fp8.py +++ b/auto_round/export/export_to_autoround/export_to_fp8.py @@ -228,7 +228,7 @@ def save_quantized_as_autoround( elif cfg["in_blocks"] or ( block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize) ): - neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys}) + neq_keys = check_neq_config(cfg, **{k: quantization_config.get(k) for k in scheme_keys}) if len(neq_keys) > 0: extra_config[layer_name] = {} for key in neq_keys: diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mx.py b/auto_round/export/export_to_autoround/export_to_nvfp_mx.py index 506a6a85d..6c2926cfb 100644 --- a/auto_round/export/export_to_autoround/export_to_nvfp_mx.py +++ b/auto_round/export/export_to_autoround/export_to_nvfp_mx.py @@ -214,7 +214,7 @@ def save_quantized_as_fp( elif cfg["in_blocks"] or ( block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize) ): - neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys}) + neq_keys = check_neq_config(cfg, **{k: quantization_config.get(k) for k in scheme_keys}) if len(neq_keys) > 0: extra_config[layer_name] = {} for key in neq_keys: @@ -225,7 +225,7 @@ def save_quantized_as_fp( if regex_config is not None: for name, cfg in regex_config.items(): regex_name = to_standard_regex(name) - neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys}) + neq_keys = check_neq_config(cfg, **{k: quantization_config.get(k) for k in scheme_keys}) if len(neq_keys) > 0: extra_config[regex_name] = {} for key in neq_keys: diff --git a/auto_round/export/export_to_gguf/export.py b/auto_round/export/export_to_gguf/export.py index c3a06c2ad..5f6dd8453 100644 --- a/auto_round/export/export_to_gguf/export.py +++ b/auto_round/export/export_to_gguf/export.py @@ -146,11 +146,7 @@ def pack_gguf_layer( ): """Export the model to gguf format.""" global gguf_model_instance_global - # if output_dir is not None and os.path.exists(output_dir): - # logger.warning_once(f"{output_dir} already exists, this may cause model conflict") if "gguf_model_instance_global" not in globals(): - config = model.config - gguf_model_instance_global = [ create_model_class( output_dir, diff --git a/auto_round/export/export_to_llmcompressor/config.py b/auto_round/export/export_to_llmcompressor/config.py index 546b9addb..62017ec73 100644 --- 
a/auto_round/export/export_to_llmcompressor/config.py +++ b/auto_round/export/export_to_llmcompressor/config.py @@ -28,16 +28,17 @@ from auto_round.utils import logger -def check_compressed_tensors_supported(): # pragma: no cover +def check_compressed_tensors_supported(raise_error: bool = False): # pragma: no cover try: import compressed_tensors # noqa: F401 return True except ImportError: - logger.error( - "Please install compressed-tensors via 'pip install compressed-tensors'" " to save as llm-compressor format" - ) - exit(-1) + msg = "Please install compressed-tensors via 'pip install compressed-tensors' to save as llm-compressor format" + logger.error(msg) + if raise_error: + raise ImportError(msg) from None + return False if check_compressed_tensors_supported(): @@ -64,7 +65,7 @@ def initialize_quantization(scheme, targets=["Linear"], config_groups=None, kv_c kv_cache_scheme = kv_cache_scheme ignore = ignore using_mxfp4_for_mxfp8 = False - check_compressed_tensors_supported() + check_compressed_tensors_supported(raise_error=True) if scheme is not None and config_groups is not None: raise ValueError("Please specify either `scheme` or `config_groups`") diff --git a/auto_round/export/export_to_llmcompressor/export_to_fp.py b/auto_round/export/export_to_llmcompressor/export_to_fp.py index b33bbb951..b18d9b149 100644 --- a/auto_round/export/export_to_llmcompressor/export_to_fp.py +++ b/auto_round/export/export_to_llmcompressor/export_to_fp.py @@ -272,7 +272,7 @@ def save_quantized_as_fp( ignore.append(lm_head_name) # get llm-compressor format config - check_compressed_tensors_supported() + check_compressed_tensors_supported(raise_error=True) # Detect mixed precision by grouping quantized layers by (bits, data_type) scheme_groups = {} # (bits, data_type) -> list of layer names diff --git a/auto_round/export/export_to_llmcompressor/export_to_static_fp.py b/auto_round/export/export_to_llmcompressor/export_to_static_fp.py index c733c4510..3f5ee91f3 100644 --- a/auto_round/export/export_to_llmcompressor/export_to_static_fp.py +++ b/auto_round/export/export_to_llmcompressor/export_to_static_fp.py @@ -153,7 +153,7 @@ def save_quantized_as_static_fp( pack_layer(name, model, serialization_dict.get("data_type", "fp8"), device) # Get llm-compressor format config - check_compressed_tensors_supported() + check_compressed_tensors_supported(raise_error=True) from compressed_tensors.quantization import ( # pylint: disable=E0401 QuantizationArgs, QuantizationConfig, diff --git a/auto_round/export/utils.py b/auto_round/export/utils.py index a71aad2ab..c740791ef 100644 --- a/auto_round/export/utils.py +++ b/auto_round/export/utils.py @@ -15,6 +15,7 @@ import os import shutil +import torch import torch.nn as nn import auto_round.envs as envs @@ -26,6 +27,21 @@ ) +def _save_model_configs(model: nn.Module, save_dir: str) -> None: + if hasattr(model, "config") and model.config is not None: + model.config.save_pretrained(save_dir) + + if hasattr(model, "generation_config") and model.generation_config is not None: + model.generation_config.save_pretrained(save_dir) + + +def _state_dict_has_meta_tensor(model: nn.Module) -> bool: + for tensor in model.state_dict().values(): + if isinstance(tensor, torch.Tensor) and tensor.device.type == "meta": + return True + return False + + def is_local_pipeline_model_dir(model_dir: str) -> bool: if not model_dir or not os.path.isdir(model_dir): return False @@ -184,13 +200,17 @@ def save_model( os.makedirs(save_dir, exist_ok=True) if unsupported_meta_device(model): - if 
hasattr(model, "config") and model.config is not None: - model.config.save_pretrained(save_dir) - - if hasattr(model, "generation_config") and model.generation_config is not None: - model.generation_config.save_pretrained(save_dir) + _save_model_configs(model, save_dir) else: - model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) + has_meta_tensor = _state_dict_has_meta_tensor(model) + if has_meta_tensor: + logger.info( + "Detected meta tensors in state_dict after shard-based saving; skipping model.save_pretrained and " + "saving configs only." + ) + _save_model_configs(model, save_dir) + else: + model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) # Allow disabling copy_missing_tensors_from_source via env var AR_DISABLE_COPY_MTP_WEIGHTS, default enabled if not envs.AR_DISABLE_COPY_MTP_WEIGHTS: @@ -330,7 +350,9 @@ def filter_quantization_config(quantization_config): "scale_dtype": "torch.float16", "seqlen": 2048, } - iters = quantization_config.get("iters", 200) + iters = quantization_config.get("iters") + if iters is None: + iters = 0 default_dict["lr"] = 1.0 / iters if iters > 0 else 5e-3 default_dict["minmax_lr"] = default_dict["lr"] @@ -349,13 +371,15 @@ def filter_quantization_config(quantization_config): quantization_config.pop("act_sym", None) quantization_config.pop("act_group_size", None) - clean_list = ("supported_types", "quant_block_list") + clean_list = ("supported_types", "quant_block_list", "transform_configs") for key in list(quantization_config.keys()): if callable(key): quantization_config.pop(key) elif isinstance(quantization_config[key], (list, tuple)): if any([callable(item) for item in quantization_config[key]]): quantization_config.pop(key) + elif len(quantization_config[key]) == 0: + quantization_config.pop(key) if key in clean_list and key in quantization_config: quantization_config.pop(key) return quantization_config diff --git a/auto_round/formats.py b/auto_round/formats.py index cd7f076ea..3a19fe2d6 100644 --- a/auto_round/formats.py +++ b/auto_round/formats.py @@ -45,7 +45,9 @@ get_gguf_scheme, ) from auto_round.utils import ( + INNER_SUPPORTED_LAYER_TYPES, SUPPORTED_FORMATS, + SUPPORTED_LAYER_TYPES, check_to_quantized, compress_layer_names, copy_python_files_from_model_cache, @@ -158,7 +160,7 @@ def _check_divisible_by_32(ar): skipped_layers = [] if default_dict["data_type"] == "int" and default_dict["act_bits"] >= 16: for n, m in ar.model.named_modules(): - if type(m) in ar.supported_types or m.__class__.__name__ in ar.inner_supported_types: + if type(m) in SUPPORTED_LAYER_TYPES or m.__class__.__name__ in INNER_SUPPORTED_LAYER_TYPES: if m.weight.shape[0] % 32 or m.weight.shape[1] % 32: if ar.layer_config is None: ar.layer_config = {} @@ -378,7 +380,7 @@ def __init__(self, format, ar): if is_nv_fp(ar.data_type) or is_mx_fp(ar.data_type): from auto_round.export.export_to_llmcompressor import check_compressed_tensors_supported - check_compressed_tensors_supported() + check_compressed_tensors_supported(raise_error=True) self.backend = LLMCompressorFormat(ar.data_type, ar) elif is_dynamic_afp8(ar) and is_block_wfp8(ar): self.backend = LLMCompressorFormat(AutoRoundExportFormat.FP8_BLOCK.value, ar) @@ -690,7 +692,7 @@ def check_and_reset_format(self, ar): if not awq_supported: logger.warning(f"The AutoAWQ format may not be supported due to {info}") if ar.bits != 4: - raise ValueError("The AWQ format only supports W4 quantization ") + raise ValueError(f"auto_awq format 
support quantization scheme with W4A16 but got bits={ar.bits}") if self.backend is None: _check_divisible_by_32(ar) @@ -759,6 +761,7 @@ class GGUFFormat(OutputFormat): def __init__(self, format: str, ar: BaseCompressor): if format.startswith("gguf:"): + self._original_format = format # preserve "gguf:q2_k_mixed" etc. for Phase 2b self.gguf_args_check(ar, format, model_type=ModelType.TEXT) if ar.mllm: self.gguf_args_check(ar, format, model_type=ModelType.MMPROJ) @@ -794,14 +797,14 @@ def check_scheme_args(cls: OutputFormat, scheme: QuantizationScheme) -> bool: return True def check_and_reset_format(self, ar): - if ar.iters != 0 and ar.bits != 3 and not ar.enable_alg_ext: + if getattr(ar, "iters", 0) != 0 and ar.bits != 3 and not ar.enable_alg_ext: logger.warning_once( "`iters=0` is recommended when exporting to current GGUF format" " or add `enable_alg_ext` for better accuracy with much more tuning cost." " Please refer to https://github.com/intel/auto-round/tree/main/docs/gguf_alg_ext_acc.md" " for the accuracy results." ) - elif ar.bits >= 8 and ar.iters != 0: + elif ar.bits >= 8 and getattr(ar, "iters", 0) != 0: logger.warning_once("`iters=0` is recommended for bits>=8") if getattr(ar, "quant_nontext_module", False): diff --git a/auto_round/modeling/fused_moe/gpt_oss.py b/auto_round/modeling/fused_moe/gpt_oss.py index 959e01356..da68a57fc 100644 --- a/auto_round/modeling/fused_moe/gpt_oss.py +++ b/auto_round/modeling/fused_moe/gpt_oss.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import torch import transformers from packaging import version @@ -33,6 +32,7 @@ class GPTOssSingleExpert(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype | None = None): super().__init__() self.hidden_size = hidden_size @@ -101,17 +101,26 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: B, T, H = hidden_states.shape x = hidden_states.reshape(-1, H) - # Use the original router (it returns scores and indices already softmaxed over top-k) - router_scores, router_indices = self.router(x) # scores: [tokens, E], indices: [tokens, k] + # Use the original router (it returns logits, scores and indices) + router_out = self.router(x) + if len(router_out) == 3: + _, router_scores, router_indices = router_out + else: + router_scores, router_indices = router_out final_hidden_states = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x) num_all_tokens, total_num_experts = x.size(0), self.num_experts mask_weights = torch.zeros((num_all_tokens, total_num_experts), dtype=x.dtype, device=x.device) - topk_ids, experts_mask = router_indices, router_scores - topk_ids = topk_ids.to(torch.int64) + topk_ids = router_indices.to(torch.int64) mask_weights.scatter_(-1, topk_ids, 1) + # Build per-expert routing score matrix: shape (num_experts, num_tokens) + # experts_mask[e, t] = router score of expert e for token t (0 if not selected) + expert_score_matrix = torch.zeros_like(mask_weights) + expert_score_matrix.scatter_(-1, topk_ids, router_scores) + expert_score_matrix = expert_score_matrix.transpose(0, 1) # (num_experts, num_tokens) + mask_weights = mask_weights[:num_all_tokens, :total_num_experts] mask_weights = mask_weights.transpose(0, 1) @@ -124,7 +133,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: mask_weight = mask_weights[expert_idx].unsqueeze(1) current_state_static = x * mask_weight expert_output = 
self.experts[expert_idx](current_state_static) - expert_output = expert_output * experts_mask[expert_idx].unsqueeze(1) + expert_output = expert_output * expert_score_matrix[expert_idx].unsqueeze(1) final_hidden_states += expert_output return final_hidden_states.view(B, T, H), router_scores.view(B * T, -1) diff --git a/auto_round/modeling/unfused_moe/__init__.py b/auto_round/modeling/unfused_moe/__init__.py index 3b9511731..a112a2a15 100644 --- a/auto_round/modeling/unfused_moe/__init__.py +++ b/auto_round/modeling/unfused_moe/__init__.py @@ -145,7 +145,10 @@ def get_file_path_via_model_name(model_or_path: str, file_name): def pre_check_config(model_name: str | torch.nn.Module, trust_remote_code: bool = True): if isinstance(model_name, str): - config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) + try: + config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) + except (OSError, EnvironmentError, ValueError): + return False elif isinstance(model_name, torch.nn.Module): config = getattr(model_name, "config", None) if config is None: diff --git a/auto_round/schemes.py b/auto_round/schemes.py index c3bec7ebf..a298575aa 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -13,15 +13,19 @@ # limitations under the License. import copy from copy import deepcopy -from dataclasses import dataclass, fields -from typing import Optional, Union +from dataclasses import asdict, dataclass, fields +from typing import TYPE_CHECKING, Any, Optional, Union import torch from auto_round.logger import logger +from auto_round.utils import SUPPORTED_DTYPES, infer_bits_by_data_type __all__ = ["QuantizationScheme", "get_gguf_scheme", "preset_name_to_scheme"] +if TYPE_CHECKING: + from auto_round.auto_scheme.gen_auto_scheme import AutoScheme + @dataclass class QuantizationScheme: @@ -40,7 +44,9 @@ class QuantizationScheme: @classmethod def from_dict(cls, config: dict): - return cls(**config) + field_names = {f.name for f in fields(cls)} + filtered_config = {k: v for k, v in config.items() if k in field_names} + return cls(**filtered_config) @classmethod def get_attributes(cls: "QuantizationScheme") -> list[str]: @@ -132,6 +138,147 @@ def is_preset_scheme(name: str) -> bool: return name.upper() in PRESET_SCHEMES +def _reconcile_bits_and_dtype(config: dict, prefix: str = ""): + """ + Harmonizes 'bits' and 'data_type' for weights or activations. + Ensures internal consistency by prioritizing data_type inference. + """ + + dt_key = f"{prefix}data_type" + bits_key = f"{prefix}bits" + + if config.get(dt_key) is None: + return + + # Infer the correct bit-width based on the data_type string + inferred_bits = infer_bits_by_data_type(config[dt_key]) + + if inferred_bits is not None and inferred_bits < 16: + # Check for conflict between user-specified bits and inferred bits + if inferred_bits != config.get(bits_key): + logger.warning(f"'{dt_key}' does not match '{bits_key}'. " f"Resetting '{bits_key}' to {inferred_bits}.") + config[bits_key] = inferred_bits + + # Normalize data_type (e.g., 'mx_fp4' -> 'mx') + for supported in SUPPORTED_DTYPES: + if config[dt_key] == f"{supported}{inferred_bits}": + config[dt_key] = supported + break + + +def _override_scheme_with_user_specify( + scheme: Union[str, dict, QuantizationScheme], user_scheme_overrides: dict[str, Any], return_str=True +) -> Union[str, QuantizationScheme]: + """ + Updates a base quantization scheme with user-provided overrides. + Handles GGUF formatting and synchronizes weight/activation parameters. 
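+ Illustrative usage (an assumption for documentation purposes, not part of the original patch):
+ `_override_scheme_with_user_specify("W4A16", {"group_size": 32})` starts from the W4A16
+ preset, applies the override, and back-fills any unset activation fields (act_group_size,
+ act_sym, act_data_type) from the corresponding weight settings before returning a
+ QuantizationScheme.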
+ """ + # 1. GGUF special handling: map data_type suffix to GGUF scheme names + dt_override = user_scheme_overrides.get("data_type", "") + if ( + isinstance(scheme, QuantizationScheme) or (isinstance(scheme, str) and not scheme.startswith("gguf")) + ) and dt_override.endswith("_dq"): + if "bits" not in user_scheme_overrides: + raise KeyError(f"Must specify 'bits' when using data_type={dt_override}") + + bits = user_scheme_overrides["bits"] + suffix = "k" if bits == 6 else "k_s" + scheme = f"gguf:q{bits}_{suffix}" + + # 2. Convert input scheme to a dictionary for processing + if isinstance(scheme, QuantizationScheme): + scheme_dict = asdict(scheme) + elif isinstance(scheme, str): + normalized_name = scheme.strip("'\" ").upper() + if normalized_name.startswith("GGUF") and len(user_scheme_overrides) > 0: + logger.warning_once( + "When using GGUF scheme, user-specified overrides will be ignored to ensure format compatibility." + ) + user_scheme_overrides = {} + # If no overrides exist, return the normalized string immediately + if not user_scheme_overrides and return_str: + return normalized_name + scheme_dict = asdict(preset_name_to_scheme(normalized_name)) + else: + scheme_dict = scheme.copy() + + # 3. Apply overrides and define default behaviors + scheme_dict.update(user_scheme_overrides) + + if scheme_dict.get("act_dynamic") is None: + scheme_dict["act_dynamic"] = True + + # 4. Reconcile weight settings (bits vs data_type) + _reconcile_bits_and_dtype(scheme_dict) + + # 5. Fallback logic: Inherit activation settings from weight settings + scheme_dict["act_group_size"] = ( + scheme_dict.get("act_group_size") + if scheme_dict.get("act_group_size") is not None + else scheme_dict.get("group_size") + ) + scheme_dict["act_bits"] = scheme_dict.get("act_bits") or 16 + scheme_dict["act_sym"] = ( + scheme_dict.get("act_sym") if scheme_dict.get("act_sym") is not None else scheme_dict.get("sym") + ) + + # 6. Activation data_type logic + if scheme_dict.get("act_data_type") is None: + is_supported = scheme_dict["data_type"] in SUPPORTED_DTYPES + if is_supported and scheme_dict["act_bits"] < 16: + scheme_dict["act_data_type"] = scheme_dict["data_type"] + logger.info(f"Activation adopting weight data_type: {scheme_dict['data_type']}") + else: + scheme_dict["act_data_type"] = "float" + + # 7. Reconcile activation settings + _reconcile_bits_and_dtype(scheme_dict, prefix="act_") + + return QuantizationScheme.from_dict(scheme_dict) + + +def _parse_scheme( + scheme: Union[str, dict, QuantizationScheme, "AutoScheme"], user_scheme_overrides: dict[str, Any] +) -> tuple[Union[str, QuantizationScheme], bool]: + """ + Parses the final scheme. 
+ """ + from auto_round.auto_scheme.gen_auto_scheme import AutoScheme + + is_auto_scheme = isinstance(scheme, AutoScheme) + if is_auto_scheme: + if not scheme.options: + raise ValueError("AutoScheme options cannot be empty") + else: + for option in scheme.options: + if isinstance(option, str): + if "mixed" in option: + raise ValueError(f"Mixed option {option} is not supported") + + # Map user overrides across all auto-scheme options + scheme.options = [_override_scheme_with_user_specify(opt, user_scheme_overrides) for opt in scheme.options] + + # Select the primary scheme for attribute binding (skipping BF16) + default_scheme = scheme.options[0] + for opt in scheme.options: + if opt == "BF16": + continue + if isinstance(opt, QuantizationScheme): + if opt.bits < 16 or (opt.act_bits and opt.act_bits < 16): + default_scheme = opt + break + else: + default_scheme = _override_scheme_with_user_specify(scheme, user_scheme_overrides) + + # Extract attributes from the chosen default_scheme + if isinstance(default_scheme, str): + final_attrs = _override_scheme_with_user_specify(default_scheme, user_scheme_overrides, return_str=False) + final_attrs = asdict(final_attrs) + else: + final_attrs = asdict(default_scheme) + return default_scheme, is_auto_scheme, final_attrs + + W4A16 = QuantizationScheme.from_dict( { "bits": 4, @@ -182,7 +329,6 @@ def is_preset_scheme(name: str) -> bool: } ) - W2A16G32 = QuantizationScheme.from_dict( { "bits": 2, @@ -252,7 +398,6 @@ def is_preset_scheme(name: str) -> bool: } ) - MXFP8 = QuantizationScheme.from_dict( { "bits": 8, @@ -377,7 +522,6 @@ def is_preset_scheme(name: str) -> bool: } ) - # For AutoScheme 16 bits options BF16 = QuantizationScheme.from_dict( { @@ -482,6 +626,18 @@ def get_gguf_scheme(scheme: Union[str, QuantizationScheme]) -> str: return scheme if isinstance(scheme, str): return "" + # AutoScheme is a lazy placeholder whose concrete option is resolved later + # (after model loading). It is never a GGUF scheme itself. + from auto_round.auto_scheme.gen_auto_scheme import AutoScheme + + if isinstance(scheme, AutoScheme): + # options is always a list after AutoScheme.__post_init__. + # If the primary option is a GGUF scheme, propagate it so that + # scale_dtype defaults to fp32 (GGUF convention). 
+ primary = scheme.options[0] if scheme.options else None + if isinstance(primary, str) and primary.upper().startswith("GGUF"): + return primary + return "" for key, val in PRESET_SCHEMES.items(): # For q40 or q4_1 we only support it with str scheme, otherwise it will be matched incorrectly with W4G32 if not key.upper().startswith("GGUF") or ("0" in key or "1" in key): diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index 1a336ee6f..264c77ab8 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -191,19 +191,20 @@ def patched_layer_forward( def _handle_special_model(model): - if hasattr(model, "config") and model.config.model_type == "deepseek_vl_v2": + model_type = getattr(getattr(model, "config", None), "model_type", None) + if model_type == "deepseek_vl_v2": from functools import partial model.forward = partial(_deepseek_vl2_forward, model) - if hasattr(model, "config") and model.config.model_type == "qwen2_5_omni": + if model_type == "qwen2_5_omni": from functools import partial model.forward = partial(_qwen2_5_omni_forward, model) - if hasattr(model, "config") and model.config.model_type == "qwen3_omni_moe": + if model_type == "qwen3_omni_moe": from functools import partial model.forward = partial(_qwen3_omni_moe_forward, model) - if hasattr(model, "config") and model.config.model_type == "gemma4": + if hasattr(model, "config") and model_type == "gemma4": import transformers from packaging import version diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py index d49999225..865134b8b 100644 --- a/auto_round/utils/common.py +++ b/auto_round/utils/common.py @@ -1031,3 +1031,27 @@ def compress_layer_names(names: list) -> str: parts.extend(singles) parts.sort() return ", ".join(parts) + + +def infer_bits_by_data_type(data_type: str): + """Infer bits by data_type + + Args: + data_type (str): data_type + + Returns: + int: bits inferred by data_type, None means cannot infer correct bits by data_type + """ + from auto_round.utils import SUPPORTED_DTYPES + + if data_type is None: + return 16 + for supported_dtype in SUPPORTED_DTYPES: + if data_type.startswith(supported_dtype) and len(data_type) > len(supported_dtype): + ##first check the following two bits + suc_2str = data_type[len(supported_dtype) : len(supported_dtype) + 2] + if str.isdigit(suc_2str): + return int(suc_2str) + if str.isdigit(data_type[len(supported_dtype)]): + return int(data_type[len(supported_dtype)]) + return None diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index a4727a996..c05a17e79 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -40,7 +40,6 @@ "hpu": "HABANA_VISIBLE_MODULES", } - # Note on HPU usage: # There are two modes available for enabling auto-round on HPU: # 1. 
Compile Mode @@ -360,7 +359,16 @@ def is_valid_digit(s): return device -def get_device_and_parallelism(device: Union[str, torch.device, int]) -> tuple[str, bool]: +def get_device_and_parallelism(device: Union[str, torch.device, int, dict]) -> tuple[str, bool]: + if device is None: + device = detect_device(device) + return device, False + if isinstance(device, dict): + unique_devices = set(device.values()) + if len(unique_devices) == 1: + device = next(iter(unique_devices)) + else: + device = "auto" if isinstance(device, str): if device in ["cuda", "xpu", "hpu"]: device = detect_device(device) @@ -481,6 +489,7 @@ def __enter__(self): # Create and inject fake triton module class FakeTriton: + def __getattr__(self, name): return None @@ -646,6 +655,25 @@ def _clear_memory_for_cpu_and_cuda( _malloc_trim_counter = 0 +def _force_trim_malloc() -> None: + """Unconditionally release glibc heap pages back to the OS on Linux. + + Unlike :func:`_maybe_trim_malloc`, this ignores the call-count throttle and + always invokes ``malloc_trim(0)``. Use at critical lifecycle boundaries + (end of model loading, end of post_init, start of quantize loop) where a + one-time trim has a meaningful impact on peak RSS. + """ + if os.name != "posix": + return + if os.environ.get("AR_ENABLE_MALLOC_TRIM", "1") != "1": + return + try: + libc = ctypes.CDLL("libc.so.6") + libc.malloc_trim(0) + except Exception: + pass + + def _maybe_trim_malloc() -> None: """Optionally release glibc heap pages back to OS on Linux. @@ -678,6 +706,7 @@ def _maybe_trim_malloc() -> None: class ClearMemory: + def __init__(self, device_list: list | tuple | None = None): self.device_list = device_list @@ -689,6 +718,13 @@ def __call__( from auto_round.utils.device import is_hpex_available if is_hpex_available(): + # Clear CPU-side references so Python can reclaim them. + if isinstance(tensor, list): + for i in range(len(tensor)): + tensor[i] = None + tensor = None + gc.collect() + _force_trim_malloc() memory_monitor.update_hpu(device_list) return else: @@ -1789,6 +1825,7 @@ def dump_mem_usage(msg: str = "", log_level: str = "info"): """Decorator to dump memory usage before and after a function call.""" def decorator(func): + @functools.wraps(func) def wrapper(*args, **kwargs): memory_monitor.update_cpu() diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 210905e84..04a8b0395 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -18,7 +18,7 @@ import re from collections import UserDict from pathlib import Path -from typing import Union +from typing import TYPE_CHECKING, Union import psutil import torch @@ -28,7 +28,6 @@ from auto_round import envs from auto_round.export.export_to_gguf.config import ModelType from auto_round.logger import logger -from auto_round.schemes import QuantizationScheme from auto_round.utils.common import monkey_patch_model from auto_round.utils.weight_handler import ( _dequant_fp8_linear_weight, @@ -39,6 +38,9 @@ FIX_MISTRAL_REGEX_MODEL_TYPE_LIST = ["longcat_next"] +if TYPE_CHECKING: + from auto_round.schemes import QuantizationScheme + def clean_module_parameter(submodule: torch.nn.Module, param_name: str) -> None: """This function is recommended to be used instead of module.weight = None. @@ -762,7 +764,10 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module], platform: str = No model_path = model_or_path if isinstance(model_or_path, str) else model_or_path.name_or_path # For dummy model, model_path could be "". 
- if model_path and not os.path.isdir(model_path): + # Only try to download if the path looks like a HF repo id (not a local filesystem path). + # Skip download for absolute paths or relative paths that contain current/parent dir markers. + _is_local_path = os.path.isabs(model_path) or model_path.startswith("./") or model_path.startswith("../") + if model_path and not os.path.isdir(model_path) and not _is_local_path: model_path = download_or_get_path(model_path, platform=platform) if isinstance(model_path, str): @@ -1129,6 +1134,7 @@ def check_to_quantized(config): bool: True if the configuration is valid for quantization (bits <= 8), False otherwise. """ + from auto_round.schemes import QuantizationScheme if isinstance(config, (dict, QuantizationScheme)): bits = config.get("bits", None) diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index 9b329624e..93324334b 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -744,7 +744,12 @@ def forward(self, x, *args, **kwargs): def wrapper_block( - block, enable_minmax_tuning, enable_norm_bias_tuning, enable_torch_compile=False, device="cpu", **kwargs + block, + enable_minmax_tuning, + enable_norm_bias_tuning, + enable_torch_compile=False, + device="cpu", + **kwargs, ): """Wraps the layers in the given block with a custom Wrapper module. diff --git a/benchmark_both.py b/benchmark_both.py new file mode 100644 index 000000000..320f18b11 --- /dev/null +++ b/benchmark_both.py @@ -0,0 +1,83 @@ +"""Quick A/B benchmark: old (compressors) vs new (compressors_new) architecture. + +Uses AR_DISABLE_NEW_ARCH env-var to toggle. Runs each configuration in a +subprocess to avoid cross-contamination, with a warmup run to fill OS page cache. +""" + +import json +import os +import subprocess +import sys +import time + +MODEL = "Qwen/Qwen3-0.6B" +ITERS = "200" +SCHEME = "W4A16" +DEVICE = "cuda:0" + +CMD_TEMPLATE = [ + sys.executable, + "-m", + "auto_round", + "--model_name", + MODEL, + "--scheme", + SCHEME, + "--iters", + ITERS, + "--device", + DEVICE, +] + + +def run_once(label: str, env_override: dict) -> float: + env = {**os.environ, **env_override} + print(f"\n{'='*60}") + print(f" Running: {label}") + print(f" AR_DISABLE_NEW_ARCH={env.get('AR_DISABLE_NEW_ARCH', 'unset')}") + print(f"{'='*60}", flush=True) + t0 = time.perf_counter() + proc = subprocess.run(CMD_TEMPLATE, env=env, capture_output=True, text=True) + elapsed = time.perf_counter() - t0 + if proc.returncode != 0: + print(f"STDERR:\n{proc.stderr[-2000:]}") + raise RuntimeError(f"{label} failed with rc={proc.returncode}") + print(f" {label}: {elapsed:.1f}s") + return elapsed + + +def main(): + # Warmup: fill OS page cache & JIT caches + print("Warmup run (old arch)...") + run_once("warmup", {"AR_DISABLE_NEW_ARCH": "1"}) + + # Interleaved runs to reduce bias + results = {"old": [], "new": []} + for trial in range(2): + if trial % 2 == 0: + first, second = ("old", "1"), ("new", "0") + else: + first, second = ("new", "0"), ("old", "1") + for label, flag in [first, second]: + t = run_once(f"{label} (trial {trial+1})", {"AR_DISABLE_NEW_ARCH": flag}) + results[label].append(t) + + # Summary + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + for arch in ["old", "new"]: + times = results[arch] + avg = sum(times) / len(times) + print(f" {arch}: {[f'{t:.1f}' for t in times]} avg={avg:.1f}s") + old_avg = sum(results["old"]) / len(results["old"]) + new_avg = sum(results["new"]) / len(results["new"]) + diff_pct = (new_avg - old_avg) / old_avg * 100 + print(f"\n Diff: 
{diff_pct:+.1f}% (new vs old)") + print(f" {'PASS' if abs(diff_pct) < 5 else 'FAIL'} (threshold: ±5%)") + + json.dump(results, open("benchmark_results/latest.json", "w"), indent=2) + + +if __name__ == "__main__": + main() diff --git a/performance_ut.sh b/performance_ut.sh new file mode 100644 index 000000000..12e78906b --- /dev/null +++ b/performance_ut.sh @@ -0,0 +1,115 @@ +#!/bin/bash +set -euo pipefail + +PATTERN='[-a-zA-Z0-9_]*=' + +for i in "$@"; do + case $i in + --model_name=*) + model_name=$(echo $i | sed "s/${PATTERN}//") + ;; + --scheme=*) + scheme=$(echo $i | sed "s/${PATTERN}//") + ;; + *) + echo "Parameter $i not recognized." + exit 1 + ;; + esac +done + +readonly WORKSPACE_DIR="/auto-round" +readonly LOG_DIR="${WORKSPACE_DIR}/log_dir" +readonly PERF_SCRIPT_DIR="${WORKSPACE_DIR}/.azure-pipelines/scripts/performance" +readonly BASELINE_GIT_URL="git+https://github.com/intel/auto-round.git" +readonly ITERS=200 + +log_group_start() { echo "##[group]$1"; } +log_group_end() { echo "##[endgroup]"; } +log_info() { echo -e "[\033[32mINFO\033[0m] $1"; } +log_err() { echo -e "[\033[31mERROR\033[0m] $1" >&2; } + +function setup_environment() { + log_group_start "Set up environment..." + + export TZ='Asia/Shanghai' + export TQDM_MININTERVAL=60 + export HF_HUB_DISABLE_PROGRESS_BARS=1 + export UV_NO_PROGRESS=1 + export UV_SYSTEM_PYTHON=1 + + log_info "Creating log directory: ${LOG_DIR}" + mkdir -p "${LOG_DIR}" + + log_info "Downloading model: ${model_name}" + hf download "${model_name}" + + log_group_end +} + +function install_auto_round() { + local install_source=$1 + local mode_name=$2 + + log_group_start "Install requirements for [${mode_name}]..." + + ( + cd "${WORKSPACE_DIR}" + log_info "Uninstalling existing auto-round..." + uv pip uninstall auto-round || true + + log_info "Installing auto-round from: ${install_source}" + BUILD_HPU_ONLY=1 uv pip install "${install_source}" + ) + + log_group_end +} + +function run_performance_test() { + local test_mode=$1 + local log_file="${LOG_DIR}/perf_test_${test_mode}.log" + + log_group_start "Run ${test_mode} performance test (${scheme})..." + + ( + cd "${PERF_SCRIPT_DIR}" + log_info "Executing auto-round for ${scheme}. Logging to ${log_file}" + auto-round \ + --model_name "${model_name}" \ + --scheme "${scheme}" \ + --iters "${ITERS}" \ + --enable_torch_compile \ + --device hpu \ + --output_dir "./${test_mode}" 2>&1 | tee -a "${log_file}" + ) + + log_group_end +} + +function run_performance_check() { + log_group_start "Check performance results..." + + ( + cd "${PERF_SCRIPT_DIR}" + log_info "Executing check_performance.py" + python check_performance.py + ) + + log_group_end +} + +function main() { + setup_environment + + install_auto_round "." "current" + run_performance_test "current" + + install_auto_round "${BASELINE_GIT_URL}" "baseline" + run_performance_test "baseline" + + run_performance_check + + log_info "All tasks completed successfully." +} + +main "$@" \ No newline at end of file diff --git a/profile_rss_per_block.py b/profile_rss_per_block.py new file mode 100644 index 000000000..66b13ce59 --- /dev/null +++ b/profile_rss_per_block.py @@ -0,0 +1,203 @@ +"""Granular per-block RSS profiling for peak RAM regression diagnosis. + +Instruments both old and new architecture to measure RSS at key points +within the per-block quantization loop. 
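+ Note (added for clarity): the measurements assume Linux; ru_maxrss is reported in KB and
+ malloc_trim(0) is loaded from glibc via ctypes, so the explicit trim step is effectively a
+ no-op on other platforms.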
+ +Usage: + # New arch: + python profile_rss_per_block.py + # Old arch: + AR_DISABLE_NEW_ARCH=1 python profile_rss_per_block.py +""" + +import gc +import os +import resource +import sys +import time + + +def rss_mb(): + """Get current RSS in MB (no gc.collect - raw measurement).""" + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 # KB -> MB on Linux + + +def rss_mb_clean(): + """Get current RSS in MB after gc.collect.""" + gc.collect() + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 + + +# Use psutil for live RSS (ru_maxrss is peak, not current) +import psutil + +_proc = psutil.Process() + + +def live_rss_mb(): + """Current RSS in MB (not peak).""" + return _proc.memory_info().rss / (1024 * 1024) + + +def live_rss_mb_clean(): + gc.collect() + try: + import ctypes + + libc = ctypes.CDLL("libc.so.6") + libc.malloc_trim(0) + except Exception: + pass + return _proc.memory_info().rss / (1024 * 1024) + + +arch = os.environ.get("AR_DISABLE_NEW_ARCH", "0") +arch_label = "OLD" if arch == "1" else "NEW" +print(f"\n{'='*70}") +print(f" {arch_label} Architecture - Granular Per-Block RSS Profiling") +print(f"{'='*70}") +print(f"Before import RSS: {live_rss_mb():.1f} MB") + +# Monkey-patch to add instrumentation +if arch != "1": + # NEW ARCH: patch CalibCompressor._quantize_single_block + from auto_round.compressors_new import calib as calib_mod + + _orig_quantize_single_block = calib_mod.CalibCompressor._quantize_single_block + _orig_quantize_blocks = calib_mod.CalibCompressor._quantize_blocks + + _block_rss_log = [] + + def _patched_quantize_single_block(self, model, m, input_ids, input_others, q_input): + block_idx = len(_block_rss_log) + rss_before = live_rss_mb() + + result = _orig_quantize_single_block(self, model, m, input_ids, input_others, q_input) + + rss_after_return = live_rss_mb() + gc.collect() + rss_after_gc = live_rss_mb() + try: + import ctypes + + libc = ctypes.CDLL("libc.so.6") + libc.malloc_trim(0) + except Exception: + pass + rss_after_trim = live_rss_mb() + + entry = { + "block": block_idx, + "before": rss_before, + "after_return": rss_after_return, + "after_gc": rss_after_gc, + "after_trim": rss_after_trim, + "delta_return": rss_after_return - rss_before, + "delta_gc": rss_after_gc - rss_before, + "delta_trim": rss_after_trim - rss_before, + } + _block_rss_log.append(entry) + print( + f" Block {block_idx:2d}: before={rss_before:.1f} after_ret={rss_after_return:.1f} " + f"after_gc={rss_after_gc:.1f} after_trim={rss_after_trim:.1f} " + f"delta_ret={entry['delta_return']:+.1f} delta_trim={entry['delta_trim']:+.1f} MB", + flush=True, + ) + return result + + calib_mod.CalibCompressor._quantize_single_block = _patched_quantize_single_block + +else: + # OLD ARCH: patch LLMCompressor._quantize_block + from auto_round.compressors import base as base_mod + + _orig_quantize_block = base_mod.LLMCompressor._quantize_block + + _block_rss_log = [] + + def _patched_quantize_block(self, block, input_ids, input_others, q_input=None, device="cpu", auto_offload=True): + block_idx = len(_block_rss_log) + rss_before = live_rss_mb() + + result = _orig_quantize_block(self, block, input_ids, input_others, q_input, device, auto_offload) + + rss_after_return = live_rss_mb() + gc.collect() + rss_after_gc = live_rss_mb() + try: + import ctypes + + libc = ctypes.CDLL("libc.so.6") + libc.malloc_trim(0) + except Exception: + pass + rss_after_trim = live_rss_mb() + + entry = { + "block": block_idx, + "before": rss_before, + "after_return": rss_after_return, + "after_gc": 
rss_after_gc, + "after_trim": rss_after_trim, + "delta_return": rss_after_return - rss_before, + "delta_gc": rss_after_gc - rss_before, + "delta_trim": rss_after_trim - rss_before, + } + _block_rss_log.append(entry) + print( + f" Block {block_idx:2d}: before={rss_before:.1f} after_ret={rss_after_return:.1f} " + f"after_gc={rss_after_gc:.1f} after_trim={rss_after_trim:.1f} " + f"delta_ret={entry['delta_return']:+.1f} delta_trim={entry['delta_trim']:+.1f} MB", + flush=True, + ) + return result + + base_mod.LLMCompressor._quantize_block = _patched_quantize_block + +print(f"After import RSS: {live_rss_mb():.1f} MB") + +from auto_round import AutoRound + +print(f"After AutoRound import RSS: {live_rss_mb():.1f} MB") + +import shutil + +save_dir = "/tmp/profile_rss_output" +shutil.rmtree(save_dir, ignore_errors=True) + +print("\nCreating AutoRound instance...") +ar = AutoRound( + model="Qwen/Qwen3-0.6B", + scheme="FP8_STATIC", + iters=200, + nsamples=128, + enable_torch_compile=True, +) +print(f"After init RSS: {live_rss_mb():.1f} MB") +print(f"After init RSS (clean): {live_rss_mb_clean():.1f} MB") + +print("\nStarting quantize_and_save...\n") +model, folder = ar.quantize_and_save(output_dir=save_dir, format="llm_compressor") + +print(f"\n{'='*70}") +print(f" SUMMARY ({arch_label} Architecture)") +print(f"{'='*70}") +print(f"Final RSS: {live_rss_mb():.1f} MB") +print(f"Final RSS (clean): {live_rss_mb_clean():.1f} MB") +print("\nPer-block deltas (after return, after gc+trim):") +for e in _block_rss_log: + print( + f" Block {e['block']:2d}: delta_ret={e['delta_return']:+.1f} delta_trim={e['delta_trim']:+.1f} MB " + f"(abs: {e['after_trim']:.1f} MB)" + ) + +# Compute growth rate +if len(_block_rss_log) >= 2: + first = _block_rss_log[0]["after_trim"] + last = _block_rss_log[-1]["after_trim"] + n = len(_block_rss_log) - 1 + print(f"\nGrowth: {first:.1f} -> {last:.1f} MB over {n} blocks = {(last-first)/n:.1f} MB/block avg") + +print(f"\nPeak RSS (ru_maxrss): {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.1f} MB") + +shutil.rmtree(save_dir, ignore_errors=True) diff --git a/setup.py b/setup.py index 1b759fe83..6bb184bcb 100644 --- a/setup.py +++ b/setup.py @@ -130,7 +130,6 @@ def fetch_requirements(path): # python setup.py hpu install ############################################################################### - HPU_REQUIREMENTS_FILE = "requirements-hpu.txt" HPU_INSTALL_CFG = { "include_packages": find_packages( @@ -144,7 +143,6 @@ def fetch_requirements(path): "install_requires": fetch_requirements(HPU_REQUIREMENTS_FILE), } - # Support legacy `python setup.py hpu install` invocation for backward compatibility. # For python -m build / uv build, use the BUILD_HPU_ONLY=1 environment variable instead. 
if __name__ == "__main__": @@ -186,5 +184,10 @@ def fetch_requirements(path): "License :: OSI Approved :: Apache Software License", ], include_package_data=True, - package_data={"": ["mllm/templates/*.json", "experimental/transform/utils/hadamards.safetensors"]}, + package_data={ + "": [ + "mllm/templates/*.json", + "algorithms/transforms/rotation/utils/hadamards.safetensors", + ] + }, ) diff --git a/test/envs.py b/test/envs.py index f998fd4a3..d071b6668 100644 --- a/test/envs.py +++ b/test/envs.py @@ -39,6 +39,10 @@ def is_optimum_available(): return importlib.util.find_spec("optimum") is not None +def is_compressed_tensors_available(): + return importlib.util.find_spec("compressed_tensors") is not None + + def is_ipex_available(): try: require_version("intel-extension-for-pytorch>=2.5") @@ -135,6 +139,16 @@ def require_optimum(test_case): return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case) +def require_compressed_tensors(test_case): + """ + Decorator marking a test that requires compressed-tensors. + + These tests are skipped when compressed-tensors isn't installed. + + """ + return unittest.skipUnless(is_compressed_tensors_available(), "test requires compressed-tensors")(test_case) + + def require_greater_than_050(test_case): """ Decorator marking a test that requires auto-round>=0.5.0. diff --git a/test/test_ark/test_model.py b/test/test_ark/test_model.py index ca0d8b4b6..476dc7ec9 100644 --- a/test/test_ark/test_model.py +++ b/test/test_ark/test_model.py @@ -44,14 +44,16 @@ def main_op(self, format, bits, group_size, sym, dtype, device, fast_cfg=True, t else: autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path, format=format) ##will convert to gptq model + _, saved_folder = autoround.quantize_and_save( + output_dir=quantized_model_path, format=format + ) ##will convert to gptq model quantization_config = AutoRoundConfig(backend="ark") model = AutoModelForCausalLM.from_pretrained( - quantized_model_path, dtype=dtype, device_map=device, quantization_config=quantization_config + saved_folder, dtype=dtype, device_map=device, quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(saved_folder) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=tar_acc, batch_size=32, limit=limit) torch.xpu.empty_cache() diff --git a/test/test_cpu/advanced/test_low_precision_input_model.py b/test/test_cpu/advanced/test_low_precision_input_model.py index 9b0779954..9087e0de0 100644 --- a/test/test_cpu/advanced/test_low_precision_input_model.py +++ b/test/test_cpu/advanced/test_low_precision_input_model.py @@ -97,6 +97,6 @@ def test_w4a16_to_mxfp4(self, tmp_path): iters=2, nsamples=2, ) - ar.quantize_and_save(tmp_path, format="llm_compressor") - model = transformers.AutoModelForCausalLM.from_pretrained(tmp_path) + _, quantized_model_path = ar.quantize_and_save(tmp_path, format="llm_compressor") + model = transformers.AutoModelForCausalLM.from_pretrained(quantized_model_path) assert model, "Failed to load the quantized model" diff --git a/test/test_cpu/backends/test_torch_backend.py b/test/test_cpu/backends/test_torch_backend.py index 14951f2e6..7e45c69c4 100644 --- a/test/test_cpu/backends/test_torch_backend.py +++ b/test/test_cpu/backends/test_torch_backend.py @@ -41,23 +41,25 @@ def test_torch_4bits_asym(self, dataloader): 
dataset=dataloader, ) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:gptqmodel") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:gptqmodel" + ) quantization_config = AutoRoundConfig(backend="torch") model = AutoModelForCausalLM.from_pretrained( quantized_model_path, dtype=torch.float16, device_map="cpu", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.35, batch_size=16, limit=10) torch.cuda.empty_cache() model = AutoModelForCausalLM.from_pretrained( - self.save_folder, dtype=torch.bfloat16, device_map="cpu", quantization_config=quantization_config + quantized_model_path, dtype=torch.bfloat16, device_map="cpu", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.35, batch_size=16, limit=10) torch.cuda.empty_cache() @@ -77,14 +79,16 @@ def test_torch_4bits_sym(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") ##will convert to gptq model + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round" + ) ##will convert to gptq model quantization_config = AutoRoundConfig(backend="torch") model = AutoModelForCausalLM.from_pretrained( quantized_model_path, dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.28, batch_size=32, limit=1000) torch.cuda.empty_cache() diff --git a/test/test_cpu/core/test_autoround.py b/test/test_cpu/core/test_autoround.py index dacd7bf4a..6ed626cd8 100644 --- a/test/test_cpu/core/test_autoround.py +++ b/test/test_cpu/core/test_autoround.py @@ -20,6 +20,7 @@ class TestAutoRound: + @classmethod def setup_class(self): model_name = opt_name_or_path @@ -372,14 +373,14 @@ def test_rtn(self, tiny_opt_model_path): bits, group_size, sym = 4, 128, True autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, iters=0, nsamples=1) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") model = AutoModelForCausalLM.from_pretrained( - self.save_folder, + quantized_model_path, torch_dtype=torch.float16, device_map="auto", ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) def test_embed_quant(self, tiny_opt_model_path, dataloader): @@ -426,7 +427,9 @@ def test_fallback_layers(self, tiny_opt_model_path, dataloader): autoround.quantize() quantized_model_path = self.save_folder - autoround.save_quantized(output_dir=quantized_model_path, format="auto_round", inplace=True) + _, quantized_model_path = autoround.save_quantized( + output_dir=quantized_model_path, 
format="auto_round", inplace=True, return_folders=True + ) model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) @@ -457,7 +460,9 @@ def test_fallback_layers_regex_awq(self, tiny_opt_model_path, dataloader): autoround.quantize() quantized_model_path = self.save_folder - autoround.save_quantized(output_dir=quantized_model_path, format="auto_awq", inplace=True) + _, quantized_model_path = autoround.save_quantized( + output_dir=quantized_model_path, format="auto_awq", inplace=True, return_folders=True + ) quantization_config = AutoRoundConfig() model = AutoModelForCausalLM.from_pretrained( @@ -492,7 +497,9 @@ def test_fallback_layers_regex_gptq(self, tiny_opt_model_path, dataloader): autoround.quantize() quantized_model_path = self.save_folder - autoround.save_quantized(output_dir=quantized_model_path, format="auto_gptq", inplace=True) + _, quantized_model_path = autoround.save_quantized( + output_dir=quantized_model_path, format="auto_round", inplace=True, return_folders=True + ) quantization_config = AutoRoundConfig() model = AutoModelForCausalLM.from_pretrained( @@ -527,7 +534,9 @@ def test_fallback_layers_regex_round(self, tiny_opt_model_path, dataloader): autoround.quantize() quantized_model_path = self.save_folder - autoround.save_quantized(output_dir=quantized_model_path, format="auto_round", inplace=True) + _, quantized_model_path = autoround.save_quantized( + output_dir=quantized_model_path, format="auto_round", inplace=True, return_folders=True + ) quantization_config = AutoRoundConfig() model = AutoModelForCausalLM.from_pretrained( @@ -636,8 +645,8 @@ def test_quant_lm_head(self, tiny_untied_qwen_model_path): nsamples=1, disable_opt_rtn=True, ) - ar.quantize_and_save(output_dir=self.save_folder, format="auto_round") - model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu") + _, quantized_model_path = ar.quantize_and_save(output_dir=self.save_folder, format="auto_round") + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") assert "lm_head" in model.config.quantization_config.extra_config assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4 @@ -653,8 +662,8 @@ def test_quant_lm_head_layer_config(self, tiny_untied_qwen_model_path): disable_opt_rtn=True, layer_config=layer_config, ) - ar.quantize_and_save(output_dir=self.save_folder, format="auto_round") - model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu") + _, quantized_model_path = ar.quantize_and_save(output_dir=self.save_folder, format="auto_round") + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") assert "lm_head" in model.config.quantization_config.extra_config assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4 @@ -664,12 +673,6 @@ def test_compressor(self, tiny_qwen_vl_model_path): assert ar.optimizer == torch.optim.AdamW assert ar.mllm - # test old api - from auto_round import AutoRoundMLLM - - ar = AutoRoundMLLM(model_name) - assert ar.mllm - def test_attention_mask_in_dataset(self): from transformers import AutoTokenizer diff --git a/test/test_cpu/export/test_export.py b/test/test_cpu/export/test_export.py index 0f4ec3ce9..4f4c0d177 100644 --- a/test/test_cpu/export/test_export.py +++ b/test/test_cpu/export/test_export.py @@ -28,6 +28,7 @@ def _get_folder_size(path: str) -> float: class TestAutoRound: + @classmethod def setup_class(self): self.model_name = 
opt_name_or_path @@ -59,7 +60,7 @@ def test_autogptq_format(self, dataloader): ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") if group_size == -1: continue @@ -86,7 +87,7 @@ def test_autoround_format(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") if group_size == -1: continue @@ -111,7 +112,9 @@ def test_autoround_awq_format(self, dataloader): ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:auto_awq") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:auto_awq" + ) # quantization_config = AutoRoundConfig( # backend="cpu" @@ -220,7 +223,7 @@ def test_static_afp8_export(self, static_kv_dtype): static_kv_dtype=static_kv_dtype, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") f = safe_open(os.path.join(quantized_model_path, "model.safetensors"), framework="pt") assert "model.decoder.layers.8.self_attn.k_proj.input_scale" in f.keys() assert "model.decoder.layers.8.self_attn.k_proj.weight_scale" in f.keys() @@ -280,7 +283,7 @@ def test_static_afp8_export(self, static_kv_dtype): act_group_size=0, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") f = safe_open(os.path.join(quantized_model_path, "model.safetensors"), framework="pt") assert "model.decoder.layers.8.self_attn.k_proj.input_scale" in f.keys() @@ -305,7 +308,7 @@ def test_static_fp8_attn(self): static_attention_dtype="fp8", ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") f = safe_open(os.path.join(quantized_model_path, "model.safetensors"), framework="pt") assert "model.decoder.layers.8.self_attn.k_proj.input_scale" in f.keys() assert "model.decoder.layers.8.self_attn.k_proj.weight_scale" in f.keys() @@ -363,7 +366,9 @@ def test_gptq_lmhead_export(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") + compressed_model, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_gptq" + ) lm_head = compressed_model.lm_head assert hasattr(lm_head, "bits") and lm_head.bits == 4, "Illegal GPTQ quantization for lm_head layer" quantization_config = AutoRoundConfig() @@ -383,6 +388,7 @@ def test_export_format(self): self.model_name, scheme="FP8_STATIC", ) + autoround.post_init() format_list = get_formats("auto_round, llm_compressor, auto_round:llm_compressor", autoround) assert len(format_list) == 3 assert format_list[0].output_format == "auto_round" @@ -396,6 +402,7 @@ def test_export_format(self): 
self.model_name, scheme="W4A16", ) + autoround.post_init() format_list = get_formats("auto_round:auto_awq, auto_gptq", autoround) assert format_list[0].output_format == "auto_round" assert format_list[0].get_backend_name() == "auto_round:auto_awq" @@ -406,6 +413,7 @@ def test_export_format(self): model=self.model_name, scheme="INT8", ) + autoround.post_init() format_list = get_formats("llm_compressor, auto_round:llm_compressor", autoround) assert format_list[0].output_format == "llm_compressor" assert format_list[0].get_backend_name() == "llm_compressor:int8_w8a8" @@ -417,6 +425,7 @@ def test_export_format(self): model=self.model_name, scheme="INT8_W8A8", ) + autoround_old.post_init() format_list_old = get_formats("llm_compressor, auto_round:llm_compressor", autoround_old) assert format_list_old[0].output_format == "llm_compressor" assert format_list_old[0].get_backend_name() == "llm_compressor:int8_w8a8" @@ -433,6 +442,7 @@ def test_export_format_with_scheme(self, tiny_qwen_model_path): group_size=32, sym=True, ) + ar.post_init() with pytest.raises(ValueError, match="auto_awq format support quantization scheme with W4A16 but got bits=2"): get_formats("auto_round:auto_awq", ar) @@ -446,6 +456,7 @@ def test_export_format_with_scheme(self, tiny_qwen_model_path): group_size=32, sym=True, ) + ar.post_init() with pytest.raises(ValueError, match="but got data_type=fp, bits=4"): get_formats("auto_round:llm_compressor", ar) @@ -456,6 +467,7 @@ def test_export_format_with_scheme(self, tiny_qwen_model_path): group_size=256, sym=True, ) + ar.post_init() get_formats("auto_round:auto_awq", ar) def test_autoawq_qwen3_vl_infer(self, dataloader): @@ -469,7 +481,9 @@ def test_autoawq_qwen3_vl_infer(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, inplace=False, format="auto_awq") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, inplace=False, format="auto_awq" + ) # Check items of modules_to_not_convert in quantization config quantization_config_path = f"{quantized_model_path}/quantization_config.json" @@ -508,7 +522,7 @@ def test_llmc_dynamic_wint8aint8_export(self, iters, use_dataloader, scheme, dat scheme=scheme, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor") with safe_open(os.path.join(quantized_model_path, "model.safetensors"), framework="pt") as f: assert "model.decoder.layers.8.self_attn.k_proj.weight_scale" in f.keys() assert f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype == torch.int8 @@ -536,7 +550,7 @@ def test_llmc_wint_a16_export(self, scheme, bits, group_size, sym): sym=sym, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor") with safe_open(os.path.join(quantized_model_path, "model.safetensors"), framework="pt") as f: # weights must be packed as int32 (compressed-tensors stores both int4 and int8 as torch.int32) weight = f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight_packed") diff --git a/test/test_cpu/export/test_gguf_format.py b/test/test_cpu/export/test_gguf_format.py index 945c28db0..9666bb286 100644 --- 
a/test/test_cpu/export/test_gguf_format.py +++ b/test/test_cpu/export/test_gguf_format.py @@ -46,7 +46,9 @@ def test_q4_0(self, tiny_qwen_model_path): ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, inplace=False, format="gguf:q4_0") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, inplace=False, format="gguf:q4_0" + ) gguf_file = os.listdir(quantized_model_path)[0] assert gguf_file.endswith(".gguf"), "Saved file is not in gguf format" # Accuracy test is covered in test_cuda/export/test_gguf_format.py::TestAutoRound::test_q4_0_accuracy @@ -59,10 +61,12 @@ def test_func(self): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, inplace=False, format="gguf:q*_1") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, inplace=False, format="gguf:q*_1" + ) assert autoround.group_size == 32 assert not autoround.sym - gguf_file = os.listdir(self.save_dir)[0] + gguf_file = os.listdir(quantized_model_path)[0] model = AutoModelForCausalLM.from_pretrained(quantized_model_path, gguf_file=gguf_file, device_map="auto") eval_generated_prompt(model, self.tokenizer) @@ -91,7 +95,9 @@ def test_q4_k_m(self, dataloader, tiny_qwen_model_path): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_k_m,fake") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="gguf:q4_k_m,fake" + ) assert autoround.layer_config["model.layers.1.self_attn.v_proj"]["super_group_size"] == 16 assert autoround.layer_config["model.layers.1.self_attn.v_proj"]["data_type"] == "int_sym_dq" assert autoround.layer_config["model.layers.0.self_attn.v_proj"]["data_type"] == "int_asym_dq" @@ -134,9 +140,9 @@ def test_all_format(self, tiny_qwen_model_path): shutil.rmtree("../../tmp_autoround", ignore_errors=True) def test_vlm_gguf(self, tiny_qwen_vl_model_path): - from auto_round import AutoRoundMLLM + from auto_round import AutoRound - autoround = AutoRoundMLLM( + autoround = AutoRound( tiny_qwen_vl_model_path, iters=0, nsamples=8, @@ -144,8 +150,8 @@ def test_vlm_gguf(self, tiny_qwen_vl_model_path): quant_nontext_module=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") - assert "mmproj-model.gguf" in os.listdir(self.save_dir) + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") + assert "mmproj-model.gguf" in os.listdir(quantized_model_path) for file_name in os.listdir(quantized_model_path): file_size = os.path.getsize(os.path.join(quantized_model_path, file_name)) / 1024**2 if file_name == "mmproj-model.gguf": @@ -154,9 +160,9 @@ def test_vlm_gguf(self, tiny_qwen_vl_model_path): assert file_size < 270, f"file size {file_size} MB is too large for non-quantized mmproj-model.gguf" def test_vlm_gguf_wo_quant_nontext_module(self, tiny_qwen_vl_model_path): - from auto_round import AutoRoundMLLM + from auto_round import AutoRound - autoround = AutoRoundMLLM( + autoround = AutoRound( tiny_qwen_vl_model_path, iters=0, nsamples=8, @@ -164,8 +170,8 @@ def test_vlm_gguf_wo_quant_nontext_module(self, tiny_qwen_vl_model_path): quant_nontext_module=False, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") - assert 
"mmproj-model.gguf" in os.listdir(self.save_dir) + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") + assert "mmproj-model.gguf" in os.listdir(quantized_model_path) for file_name in os.listdir(quantized_model_path): file_size = os.path.getsize(os.path.join(quantized_model_path, file_name)) / 1024**2 if file_name == "mmproj-model.gguf": @@ -254,7 +260,7 @@ def test_q2k_mixed(self, tiny_qwen_moe_model_path): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q2_k_mixed") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q2_k_mixed") gguf_file = os.listdir(quantized_model_path)[0] file_size = os.path.getsize(os.path.join(quantized_model_path, gguf_file)) / 1024**2 assert file_size < 1150, f"file size {file_size} MB is too large for q2_k_mixed format" diff --git a/test/test_cpu/export/test_llmc_format.py b/test/test_cpu/export/test_llmc_format.py index 2a6a0a5ec..c00e5e6df 100644 --- a/test/test_cpu/export/test_llmc_format.py +++ b/test/test_cpu/export/test_llmc_format.py @@ -10,10 +10,14 @@ from auto_round.export.export_to_llmcompressor import export_to_fp as llmc_fp_export from auto_round.export.export_to_llmcompressor import export_to_static_fp as llmc_static_fp_export +from ...envs import is_compressed_tensors_available from ...helpers import forbid_threaded_packing, get_model_path, opt_name_or_path +pytestmark = pytest.mark.skipif(not is_compressed_tensors_available(), reason="test requires compressed-tensors") + class TestLLMC: + @classmethod def setup_class(self): self.model_name = get_model_path("stas/tiny-random-llama-2") @@ -52,7 +56,7 @@ def test_llmcompressor_fp8(self, tmp_path): nsamples=2, iters=0, ) - autoround.quantize_and_save(tmp_path, format="llm_compressor") + _, quantized_model_path = autoround.quantize_and_save(tmp_path, format="llm_compressor") # from vllm import LLM # model = LLM(tmp_path) # result = model.generate("Hello my name is") @@ -62,13 +66,13 @@ def test_llmcompressor_fp8(self, tmp_path): from safetensors import safe_open - config = json.load(open(os.path.join(tmp_path, "config.json"))) + config = json.load(open(os.path.join(quantized_model_path, "config.json"))) assert "group_0" in config["quantization_config"]["config_groups"] assert config["quantization_config"]["config_groups"]["group_0"]["input_activations"]["num_bits"] == 8 assert config["quantization_config"]["config_groups"]["group_0"]["weights"]["strategy"] == "channel" assert config["quantization_config"]["quant_method"] == "compressed-tensors" - f = safe_open(os.path.join(tmp_path, "model.safetensors"), framework="pt") + f = safe_open(os.path.join(quantized_model_path, "model.safetensors"), framework="pt") assert len(f.get_tensor("model.decoder.layers.0.fc1.weight_scale").shape) == 2 def test_autoround_llmcompressor_fp8(self, tmp_path): @@ -82,11 +86,11 @@ def test_autoround_llmcompressor_fp8(self, tmp_path): nsamples=2, iters=0, ) - autoround.quantize_and_save(tmp_path, format="auto_round:llm_compressor") + _, quantized_model_path = autoround.quantize_and_save(tmp_path, format="auto_round:llm_compressor") import json - config = json.load(open(os.path.join(tmp_path, "config.json"))) + config = json.load(open(os.path.join(quantized_model_path, "config.json"))) assert "group_0" in config["quantization_config"]["config_groups"] assert 
config["quantization_config"]["config_groups"]["group_0"]["input_activations"]["num_bits"] == 8 assert config["quantization_config"]["config_groups"]["group_0"]["weights"]["strategy"] == "tensor" @@ -101,7 +105,7 @@ def test_mxfp8_llmcompressor_format(self, tiny_opt_model_path, tmp_path): disable_opt_rtn=True, scheme=scheme, ) - compressed_model, _ = ar.quantize_and_save(output_dir=tmp_path, format="llm_compressor") + compressed_model, tmp_path = ar.quantize_and_save(output_dir=tmp_path, format="llm_compressor") tmp_layer = compressed_model.model.decoder.layers[1].self_attn.q_proj assert ( hasattr(tmp_layer, "weight_scale") @@ -135,7 +139,7 @@ def test_mixed_precision_llmcompressor_format(self, tiny_opt_model_path, tmp_pat disable_opt_rtn=True, scheme=scheme, ) - ar.quantize_and_save(output_dir=tmp_path, format="llm_compressor") + _, tmp_path = ar.quantize_and_save(output_dir=tmp_path, format="llm_compressor") model = AutoModelForCausalLM.from_pretrained(tmp_path, torch_dtype="auto", trust_remote_code=True) op = model.model.decoder.layers[0].fc1 if op.quantization_scheme.targets != ["Linear"]: diff --git a/test/test_cpu/models/test_conv1d.py b/test/test_cpu/models/test_conv1d.py index 8a30b8207..717178e73 100644 --- a/test/test_cpu/models/test_conv1d.py +++ b/test/test_cpu/models/test_conv1d.py @@ -39,7 +39,7 @@ def test_quant(self, dataloader): ) autoround.quantize() - autoround.save_quantized(self.save_dir) + _, quantized_model_path = autoround.save_quantized(self.save_dir, return_folders=True) - model = AutoModelForCausalLM.from_pretrained(self.save_dir, device_map="cpu", trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu", trust_remote_code=True) model_infer(model, self.tokenizer) diff --git a/test/test_cpu/models/test_mllm.py b/test/test_cpu/models/test_mllm.py index 9fd6c69df..2c0c71bd4 100644 --- a/test/test_cpu/models/test_mllm.py +++ b/test/test_cpu/models/test_mllm.py @@ -4,13 +4,14 @@ import pytest from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration -from auto_round import AutoRoundMLLM +from auto_round import AutoRound from auto_round.utils import get_block_names from ...helpers import get_model_path, opt_name_or_path class FakeDataLoader: + def __init__(self): self.batch_size = 1 @@ -27,7 +28,8 @@ def __iter__(self): yield self.data -class TestAutoRoundMLLM: +class TestAutoRound: + @classmethod def setup_class(self): self.model_name = get_model_path("Qwen/Qwen2-VL-2B-Instruct") @@ -43,7 +45,7 @@ def setup_save_dir(self, tmp_path): def test_tune(self, tiny_qwen_vl_model_path): bits, group_size = 4, 128 - autoround = AutoRoundMLLM( + autoround = AutoRound( model=tiny_qwen_vl_model_path, bits=bits, group_size=group_size, @@ -64,7 +66,7 @@ def test_quant_vision(self, tiny_qwen_vl_model_path): ## bug need to fix tiny_qwen_vl_model_path, trust_remote_code=True, device_map="auto" ) bits, group_size = 4, 128 - autoround = AutoRoundMLLM( + autoround = AutoRound( model, tokenizer, processor=processor, @@ -120,7 +122,7 @@ def test_diff_dataset(self, tiny_qwen_vl_model_path): ) bits, group_size = 4, 128 dataset = ["dataset test", "list test"] - autoround = AutoRoundMLLM( + autoround = AutoRound( model, tokenizer, processor=processor, @@ -154,7 +156,7 @@ def test_str_input(self): ) bits, group_size = 4, 128 dataset = ["test pure text", "input for mllm"] - autoround = AutoRoundMLLM( + autoround = AutoRound( model, tokenizer, processor=processor, @@ -215,7 +217,7 @@ def 
test_qwen2_5(self, tiny_qwen_2_5_vl_model_path): model_name = tiny_qwen_2_5_vl_model_path model, processor, tokenizer, image_processor = mllm_load_model(model_name) - autoround = AutoRoundMLLM( + autoround = AutoRound( model, tokenizer, iters=1, @@ -225,15 +227,17 @@ def test_qwen2_5(self, tiny_qwen_2_5_vl_model_path): processor=processor, image_processor=image_processor, ) - autoround.quantize_and_save(self.save_dir, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(self.save_dir, format="auto_round") import requests from PIL import Image from transformers import AutoProcessor, AutoTokenizer, Qwen2_5_VLForConditionalGeneration - model = Qwen2_5_VLForConditionalGeneration.from_pretrained(self.save_dir, torch_dtype="auto", device_map="auto") + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + quantized_model_path, torch_dtype="auto", device_map="auto" + ) image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" - processor = AutoProcessor.from_pretrained(self.save_dir) + processor = AutoProcessor.from_pretrained(quantized_model_path) messages = [ { "role": "user", @@ -269,7 +273,7 @@ def test_mllm_early_stop_tracking(self, tiny_qwen_2_5_vl_model_path): model_name = tiny_qwen_2_5_vl_model_path model, processor, tokenizer, image_processor = mllm_load_model(model_name) - autoround = AutoRoundMLLM( + autoround = AutoRound( model, tokenizer, iters=1, diff --git a/test/test_cpu/models/test_moe_model.py b/test/test_cpu/models/test_moe_model.py index ea2e5a6ed..9c25b1d81 100644 --- a/test/test_cpu/models/test_moe_model.py +++ b/test/test_cpu/models/test_moe_model.py @@ -24,7 +24,7 @@ def quantize_model(model, output_dir, scheme, iters=0, ignore_layers="self_attn, disable_opt_rtn=disable_opt_rtn, ) quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) - return quantized_model + return quantized_model, save_folder def count_modules_by_type(model, target_module_name_or_class): @@ -44,7 +44,9 @@ def count_modules_by_type(model, target_module_name_or_class): def test_gptoss(scheme, tiny_gpt_oss_model_path, tmp_path): config = AutoConfig.from_pretrained(tiny_gpt_oss_model_path, trust_remote_code=True) output_dir = str(tmp_path / "saved") - quantized_model = quantize_model(tiny_gpt_oss_model_path, output_dir, scheme, ignore_layers="self_attn,lm_head") + quantized_model, save_folder = quantize_model( + tiny_gpt_oss_model_path, output_dir, scheme, ignore_layers="self_attn,lm_head" + ) # Ensure the quantized model is not None assert quantized_model is not None, "Quantized model should not be None." @@ -63,7 +65,7 @@ def test_gptoss(scheme, tiny_gpt_oss_model_path, tmp_path): ), f"Expected {config.num_hidden_layers * 3 * config.num_local_experts} QuantLinear modules, found {quant_linear_cnt}." 
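# Minimal sketch of the expected-count arithmetic in the assert above, assuming each
# local expert is quantized into three Linear projections (e.g. gate, up and down);
# the config values below are made-up examples, not the tiny test model's real settings.
num_hidden_layers, num_local_experts = 2, 4
expected_quant_linear = num_hidden_layers * 3 * num_local_experts  # 2 * 3 * 4 = 24 QuantLinear modules
assert expected_quant_linear == 24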
# verify the quantized model can be loaded and run inference - loaded_model = GptOssForCausalLM.from_pretrained(output_dir) + loaded_model = GptOssForCausalLM.from_pretrained(save_folder) inp = torch.randint(0, 100, (1, 32)) with torch.inference_mode(): @@ -72,12 +74,14 @@ def test_gptoss(scheme, tiny_gpt_oss_model_path, tmp_path): def test_llama4(tiny_llama4_model_path): output_dir = "./tmp/test_quantized_llama4" - quantized_model = quantize_model(tiny_llama4_model_path, output_dir, "MXFP4", ignore_layers="self_attn,lm_head") + quantized_model, save_folder = quantize_model( + tiny_llama4_model_path, output_dir, "MXFP4", ignore_layers="self_attn,lm_head" + ) # Ensure the quantized model is not None assert quantized_model is not None, "Quantized model should not be None." - loaded_model = Llama4ForConditionalGeneration.from_pretrained(output_dir) + loaded_model = Llama4ForConditionalGeneration.from_pretrained(save_folder) inp = torch.randint(0, 100, (1, 32)) with torch.inference_mode(): @@ -97,9 +101,9 @@ def test_qwen3_vl_moe_mxfp(tiny_qwen3_vl_moe_model_path): disable_opt_rtn=True, ignore_layers="self_attn,lm_head, mlp.gate", ) - quantized_model, _ = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) + quantized_model, quantized_model_path = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) assert quantized_model is not None, "Quantized model should not be None." - loaded_model = Qwen3VLMoeForConditionalGeneration.from_pretrained(output_dir, device_map="cpu") + loaded_model = Qwen3VLMoeForConditionalGeneration.from_pretrained(quantized_model_path, device_map="cpu") inp = torch.randint(0, 100, (1, 32)) with torch.inference_mode(): diff --git a/test/test_cpu/models/test_omni_model.py b/test/test_cpu/models/test_omni_model.py index ce3f43023..3c5d82548 100644 --- a/test/test_cpu/models/test_omni_model.py +++ b/test/test_cpu/models/test_omni_model.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Unit tests for Qwen2.5-Omni and Qwen3-Omni-MoE model support. Tests cover: @@ -282,6 +281,8 @@ def assert_same_weights(actual: torch.Tensor, expected: torch.Tensor): intermediate = 32 # moe_intermediate_size # Verify thinker expert weights + # Use equal_nan=True because fused expert parameters may contain + # uninitialized memory with NaN bit patterns, and NaN != NaN in IEEE 754. 
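# Minimal, self-contained illustration of the equal_nan behaviour the comment above
# relies on (plain torch semantics; the tensor values are made up for the example).
import torch

a = torch.tensor([1.0, float("nan")])
b = torch.tensor([1.0, float("nan")])
assert not torch.equal(a, b)                 # NaN != NaN under IEEE 754
assert torch.allclose(a, b, equal_nan=True)  # matching NaN positions compare equal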
for i in range(4): expert = model.thinker.model.layers[0].mlp.experts[i] assert_same_weights(expert.gate_proj.weight.data, thinker_gate_up[i, :intermediate, :]) diff --git a/test/test_cpu/quantization/test_act_quantization.py b/test/test_cpu/quantization/test_act_quantization.py index 4798e0fe4..6546efd62 100644 --- a/test/test_cpu/quantization/test_act_quantization.py +++ b/test/test_cpu/quantization/test_act_quantization.py @@ -108,7 +108,7 @@ def test_act_config_MXFP4_saving(self, tiny_opt_model_path, dataloader): layer_config=layer_config, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") assert "lm_head" not in model.config.quantization_config.extra_config @@ -129,7 +129,7 @@ def test_act_config_NVFP4_saving(self, tiny_opt_model_path, dataloader): layer_config=layer_config, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") kproj_config = model.config.quantization_config.extra_config["model.decoder.layers.1.self_attn.k_proj"] assert "act_bits" in kproj_config.keys() and kproj_config["act_bits"] == 16 @@ -148,7 +148,7 @@ def test_WOQ_config_INT_saving(self, tiny_opt_model_path, dataloader): layer_config=layer_config, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") extra_config = model.config.quantization_config.extra_config @@ -178,7 +178,7 @@ def test_act_config_FP8_saving(self, tiny_opt_model_path, dataloader): layer_config=layer_config, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") from transformers import AutoConfig extra_config = AutoConfig.from_pretrained(quantized_model_path).quantization_config["extra_config"] diff --git a/test/test_cpu/quantization/test_asym.py b/test/test_cpu/quantization/test_asym.py index 05f9c1990..0f792a3a1 100644 --- a/test/test_cpu/quantization/test_asym.py +++ b/test/test_cpu/quantization/test_asym.py @@ -28,15 +28,15 @@ def test_asym_group_size(self, tiny_opt_model_path): ar = AutoRound( tiny_opt_model_path, bits=bits, group_size=group_size, sym=sym, iters=0, seqlen=2, nsamples=1 ) - ar.quantize_and_save(format="auto_round", output_dir=self.save_folder) + _, quantized_model_path = ar.quantize_and_save(format="auto_round", output_dir=self.save_folder) model = AutoModelForCausalLM.from_pretrained( - self.save_folder, + quantized_model_path, torch_dtype="auto", device_map="auto", ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) def test_asym_bits(self, tiny_opt_model_path): @@ -45,15 +45,15 @@ def test_asym_bits(self, tiny_opt_model_path): ar = AutoRound( 
tiny_opt_model_path, bits=bits, group_size=group_size, sym=sym, iters=0, seqlen=2, nsamples=1 ) - ar.quantize_and_save(format="auto_round", output_dir=self.save_folder) + _, quantized_model_path = ar.quantize_and_save(format="auto_round", output_dir=self.save_folder) model = AutoModelForCausalLM.from_pretrained( - self.save_folder, + quantized_model_path, torch_dtype="auto", device_map="auto", ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) # use parameters later @@ -70,13 +70,13 @@ def test_asym_format(self, tiny_opt_model_path): nsamples=1, disable_opt_rtn=True, ) - ar.quantize_and_save(format=format, output_dir=self.save_folder) + _, quantized_model_path = ar.quantize_and_save(format=format, output_dir=self.save_folder) model = AutoModelForCausalLM.from_pretrained( - self.save_folder, + quantized_model_path, torch_dtype="auto", device_map="auto", ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) diff --git a/test/test_cpu/quantization/test_mix_bits.py b/test/test_cpu/quantization/test_mix_bits.py index 1fbbfed44..6eab008c0 100644 --- a/test/test_cpu/quantization/test_mix_bits.py +++ b/test/test_cpu/quantization/test_mix_bits.py @@ -59,7 +59,7 @@ def test_mixed_gptqmodel(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") # test original GPTQModel inference from gptqmodel import GPTQModel @@ -87,7 +87,7 @@ def test_mixed_gptqmodel_convert_to_ar(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") quantization_config = AutoRoundConfig() model = AutoModelForCausalLM.from_pretrained( quantized_model_path, device_map="cpu", quantization_config=quantization_config @@ -114,7 +114,7 @@ def test_mixed_autoround_format(self, dataloader): layer_config=layer_config, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") config_file = Path(quantized_model_path) / "config.json" with open(config_file, "r", encoding="utf-8") as f: config = json.load(f) @@ -147,7 +147,7 @@ def test_fallback_regex_for_awq_format(self, dataloader): layer_config=layer_config, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_awq") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_awq") quantization_config = AutoRoundConfig() model = AutoModelForCausalLM.from_pretrained( quantized_model_path, device_map="cpu", quantization_config=quantization_config @@ -219,7 +219,9 @@ def test_mixed_MXFP_autoround_format_loading(self, dataloader): layer_config=layer_config, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, inplace=False, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, 
inplace=False, format="auto_round" + ) model = AutoModelForCausalLM.from_pretrained( quantized_model_path, torch_dtype="auto", diff --git a/test/test_cpu/quantization/test_mxfp_nvfp.py b/test/test_cpu/quantization/test_mxfp_nvfp.py index d0c15ccf8..2b091a068 100644 --- a/test/test_cpu/quantization/test_mxfp_nvfp.py +++ b/test/test_cpu/quantization/test_mxfp_nvfp.py @@ -10,6 +10,7 @@ from auto_round import AutoRound from auto_round.export.export_to_autoround import export_to_nvfp_mx as autoround_nvfp_mx_export +from ...envs import require_compressed_tensors from ...helpers import forbid_threaded_packing, is_model_outputs_similar, transformers_version @@ -79,7 +80,9 @@ def test_nvfp4_moe_actmax_ar(self, tiny_deepseek_v2_model_path, dataloader): layer_config=layer_config, trust_remote_code=False, ) - compressed_model, _ = autoround.quantize_and_save(output_dir=self.save_dir, inplace=True, format="auto_round") + compressed_model, quantized_model_path = autoround.quantize_and_save( + output_dir=self.save_dir, inplace=True, format="auto_round" + ) lm_head = compressed_model.lm_head assert ( hasattr(lm_head, "weight_scale") @@ -88,7 +91,6 @@ def test_nvfp4_moe_actmax_ar(self, tiny_deepseek_v2_model_path, dataloader): and lm_head.weight_packed.dtype is torch.uint8 and lm_head.weight_scale.dtype is torch.float8_e4m3fn ), "Illegal NVFP4 packing for lm_head layer" - quantized_model_path = self.save_dir assert is_model_outputs_similar(model_name, quantized_model_path) def test_mxfp4_moe_ar(self, tiny_deepseek_v2_model_path, dataloader): @@ -119,6 +121,7 @@ def test_mxfp4_moe_ar(self, tiny_deepseek_v2_model_path, dataloader): and lm_head.weight_scale.dtype is torch.uint8 ), "Illegal MXFP4 packing for lm_head layer" + @require_compressed_tensors def test_mxfp4_llmcompressor_format(self, tiny_opt_model_path, dataloader): model_name = tiny_opt_model_path from transformers import AutoConfig @@ -157,6 +160,7 @@ def test_mxfp4_llmcompressor_format(self, tiny_opt_model_path, dataloader): and quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 4 ), f"Invalid MXFP4 quantization configuration: {quantization_config}" + @require_compressed_tensors def test_rtn_mxfp4_llmcompressor_format(self, tiny_opt_model_path, dataloader): model_name = tiny_opt_model_path from transformers import AutoConfig @@ -195,6 +199,7 @@ def test_rtn_mxfp4_llmcompressor_format(self, tiny_opt_model_path, dataloader): and quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 4 ), f"Invalid MXFP4 quantization configuration: {quantization_config}" + @require_compressed_tensors def test_mxfp8_llmcompressor_format(self, tiny_opt_model_path, dataloader): model_name = tiny_opt_model_path from transformers import AutoConfig @@ -208,7 +213,9 @@ def test_mxfp8_llmcompressor_format(self, tiny_opt_model_path, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor") + compressed_model, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="llm_compressor" + ) tmp_layer = compressed_model.model.decoder.layers[1].self_attn.q_proj assert ( hasattr(tmp_layer, "weight_scale") @@ -230,6 +237,7 @@ def test_mxfp8_llmcompressor_format(self, tiny_opt_model_path, dataloader): 0.05 < folder_size_gb < 0.1 ), f"Quantized model folder size {folder_size_gb:.2f} GB is outside the expected range (0.05~0.1 GB)" + @require_compressed_tensors def 
test_nvfp4_llmcompressor_format(self, tiny_opt_model_path, dataloader): model_name = tiny_opt_model_path from transformers import AutoConfig @@ -243,7 +251,9 @@ def test_nvfp4_llmcompressor_format(self, tiny_opt_model_path, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor") + compressed_model, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="llm_compressor" + ) tmp_layer = compressed_model.model.decoder.layers[1].self_attn.q_proj assert ( hasattr(tmp_layer, "weight_scale") @@ -279,7 +289,9 @@ def test_nvfp4_autoround_format(self, tiny_opt_model_path, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + compressed_model, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round" + ) tmp_layer = compressed_model.model.decoder.layers[1].self_attn.q_proj assert ( hasattr(tmp_layer, "weight_scale") @@ -334,7 +346,9 @@ def test_qwen_moe_quant_infer(self, tiny_qwen_moe_model_path, dataloader): layer_config=layer_config, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, inplace=True, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, inplace=True, format="auto_round" + ) assert is_model_outputs_similar(model_name, quantized_model_path) @pytest.mark.parametrize( @@ -371,7 +385,7 @@ def test_fp8_kv_attn(self, scheme, static_kv_dtype, static_attention_dtype, tiny ) quantized_model_path = self.save_dir - compressed_model, _ = autoround.quantize_and_save( + compressed_model, quantized_model_path = autoround.quantize_and_save( output_dir=quantized_model_path, format="auto_round", ) diff --git a/test/test_cpu/quantization/test_mxfp_save_load.py b/test/test_cpu/quantization/test_mxfp_save_load.py index 25e5a2428..ea7e4478d 100644 --- a/test/test_cpu/quantization/test_mxfp_save_load.py +++ b/test/test_cpu/quantization/test_mxfp_save_load.py @@ -61,7 +61,7 @@ def test_e2e_quant_and_load(scheme_name, weight_data_type, act_data_type): # Quantize and save the model to the temporary directory quantized_model_path = f"{temp_dir}/tmp_autoround" - autoround.quantize_and_save(format="auto_round", output_dir=quantized_model_path) + _, quantized_model_path = autoround.quantize_and_save(format="auto_round", output_dir=quantized_model_path) # Perform inference with the quantized model model = AutoModelForCausalLM.from_pretrained( diff --git a/test/test_cpu/schemes/test_scheme.py b/test/test_cpu/schemes/test_scheme.py index 8af28b7a5..41e9a9f7d 100644 --- a/test/test_cpu/schemes/test_scheme.py +++ b/test/test_cpu/schemes/test_scheme.py @@ -11,6 +11,7 @@ class TestAutoRound: + @pytest.fixture(autouse=True) def setup_save_folder(self, tmp_path): self.save_folder = str(tmp_path / "saved") @@ -35,10 +36,12 @@ def test_gguf(self, tiny_qwen_model_path, dataloader): def test_w4a16(self, tiny_opt_model_path, dataloader): ar = AutoRound(tiny_opt_model_path, scheme="W4A16", nsamples=1, iters=1, seqlen=2, dataset=dataloader) + ar.post_init() assert ar.bits == 4 def test_w2a16_rtn(self, tiny_opt_model_path, dataloader): ar = AutoRound(tiny_opt_model_path, scheme="W2A16", nsamples=1, iters=0, seqlen=2, dataset=dataloader) + ar.post_init() assert ar.bits == 2 def 
test_w4a16_mixed(self, tiny_qwen_moe_model_path, dataloader): @@ -56,13 +59,13 @@ def test_w4a16_mixed(self, tiny_qwen_moe_model_path, dataloader): low_cpu_mem_usage=False, layer_config=layer_config, ) - ar.quantize_and_save(self.save_folder) + _, quantized_model_path = ar.quantize_and_save(self.save_folder) assert ar.bits == 4 assert ar.model.model.layers[0].self_attn.q_proj.bits == 8 assert ar.model.model.layers[0].self_attn.k_proj.bits == 16 assert ar.model.model.layers[0].mlp.experts[0].up_proj.bits == 4 # assert ar.model.model.layers[0].mlp.shared_expert.gate_proj.bits == 8 # gate has been added to ignore_layers - model = transformers.AutoModelForCausalLM.from_pretrained(self.save_folder, trust_remote_code=True) + model = transformers.AutoModelForCausalLM.from_pretrained(quantized_model_path, trust_remote_code=True) assert model is not None, "Model loading failed after quantization with W4A16_MIXED scheme on MoE" def test_w4a16_mixed_mllm(self, tiny_qwen_2_5_vl_model_path, dataloader): @@ -78,8 +81,8 @@ def test_w4a16_mixed_mllm(self, tiny_qwen_2_5_vl_model_path, dataloader): dataset=dataloader, low_cpu_mem_usage=False, ) - ar.quantize_and_save(self.save_folder) - model = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(self.save_folder) + _, quantized_model_path = ar.quantize_and_save(self.save_folder) + model = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(quantized_model_path) assert model is not None, "Model loading failed after quantization with W4A16_MIXED scheme on MLLM" assert ar.bits == 4 assert ar.model.model.language_model.layers[0].self_attn.q_proj.bits == 16 @@ -87,6 +90,7 @@ def test_w4a16_mixed_mllm(self, tiny_qwen_2_5_vl_model_path, dataloader): def test_mxfp4(self, tiny_opt_model_path, dataloader): ar = AutoRound(tiny_opt_model_path, scheme="MXFP4", nsamples=1, iters=1, seqlen=2, dataset=dataloader) + ar.post_init() assert ar.bits == 4 assert ar.act_bits == 4 assert ar.data_type == "mx_fp" @@ -94,29 +98,32 @@ def test_mxfp4(self, tiny_opt_model_path, dataloader): def test_mxfp4_rceil(self, tiny_opt_model_path): ar = AutoRound(tiny_opt_model_path, scheme="MXFP4_RCEIL", nsamples=1, iters=1) + ar.post_init() assert ar.bits == 4 assert ar.act_bits == 4 assert ar.data_type == "mx_fp" assert ar.act_data_type == "mx_fp_rceil" - ar.quantize_and_save() - model = transformers.AutoModelForCausalLM.from_pretrained("tmp_autoround", trust_remote_code=True) + _, quantized_model_path = ar.quantize_and_save() + model = transformers.AutoModelForCausalLM.from_pretrained(quantized_model_path, trust_remote_code=True) assert model is not None, "Model loading failed after quantization with MXFP4 scheme" def test_vlm(self, tiny_qwen_vl_model_path): - from auto_round import AutoRoundMLLM + from auto_round import AutoRound - ar = AutoRoundMLLM(tiny_qwen_vl_model_path, scheme="W2A16", nsamples=1, iters=1, seqlen=2) + ar = AutoRound(tiny_qwen_vl_model_path, scheme="W2A16", nsamples=1, iters=1, seqlen=2) + ar.post_init() assert ar.bits == 2 assert ar.act_bits == 16 def test_nvfp4(self, tiny_opt_model_path, dataloader): ar = AutoRound(tiny_opt_model_path, scheme="NVFP4", nsamples=1, iters=1, seqlen=2, dataset=dataloader) + ar.post_init() assert ar.bits == 4 assert ar.act_bits == 4 assert ar.data_type == "nv_fp" assert ar.act_data_type == "nv_fp4_with_static_gs" - ar.quantize_and_save(self.save_folder) - model = transformers.AutoModelForCausalLM.from_pretrained(self.save_folder, trust_remote_code=True) + _, quantized_model_path = 
ar.quantize_and_save(self.save_folder) + model = transformers.AutoModelForCausalLM.from_pretrained(quantized_model_path, trust_remote_code=True) assert model is not None, "Model loading failed after quantization with NVFP4 scheme" @pytest.mark.parametrize( @@ -198,24 +205,26 @@ def test_set_scheme(self, tiny_qwen_model_path): def test_fp8_static(self, tiny_opt_model_path): ar = AutoRound(tiny_opt_model_path, scheme="FP8_STATIC", nsamples=1, iters=1) + ar.post_init() assert ar.bits == 8 assert ar.act_bits == 8 assert ar.data_type == "fp" assert ar.act_data_type == "fp" assert ar.group_size == -1 assert ar.act_dynamic is False - ar.quantize_and_save() - model = transformers.AutoModelForCausalLM.from_pretrained("tmp_autoround", trust_remote_code=True) + _, quantized_model_path = ar.quantize_and_save() + model = transformers.AutoModelForCausalLM.from_pretrained(quantized_model_path, trust_remote_code=True) assert model is not None, "Model loading failed after quantization with FP8_STATIC scheme" def test_fp8_static_rtn(self, tiny_opt_model_path): ar = AutoRound(tiny_opt_model_path, scheme="FP8_STATIC", nsamples=1, iters=0, disable_opt_rtn=True) + ar.post_init() assert ar.bits == 8 assert ar.act_bits == 8 assert ar.data_type == "fp" assert ar.act_data_type == "fp" assert ar.group_size == -1 assert ar.act_dynamic is False - ar.quantize_and_save(self.save_folder) - model = transformers.AutoModelForCausalLM.from_pretrained(self.save_folder, trust_remote_code=True) + _, quantized_model_path = ar.quantize_and_save(self.save_folder) + model = transformers.AutoModelForCausalLM.from_pretrained(quantized_model_path, trust_remote_code=True) assert model is not None, "Model loading failed after quantization with FP8_STATIC scheme" diff --git a/test/test_cpu/utils/test_generation.py b/test/test_cpu/utils/test_generation.py index 448c3cc40..2acb560a2 100644 --- a/test/test_cpu/utils/test_generation.py +++ b/test/test_cpu/utils/test_generation.py @@ -43,7 +43,9 @@ def test_4bits_sym(self, dataloader): ) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round", inplace=False) + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round", inplace=False + ) model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) @@ -78,7 +80,7 @@ def test_autoround_sym(self, dataloader): ) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") model = AutoModelForCausalLM.from_pretrained( quantized_model_path, device_map="auto", trust_remote_code=True diff --git a/test/test_cuda/advanced/test_multiple_card.py b/test/test_cuda/advanced/test_multiple_card.py index ac6d0fb22..abc3bddb6 100644 --- a/test/test_cuda/advanced/test_multiple_card.py +++ b/test/test_cuda/advanced/test_multiple_card.py @@ -198,24 +198,24 @@ def test_device_map_for_triton(self): @multi_card def test_mllm_device_map(self, tiny_qwen_2_5_vl_model_path): - from auto_round import AutoRoundMLLM + from auto_round import AutoRound device_map = "0,1" - ar = AutoRoundMLLM(tiny_qwen_2_5_vl_model_path, device_map=device_map) + ar = AutoRound(tiny_qwen_2_5_vl_model_path, device_map=device_map) assert ar.device == "cuda:0" assert ar.device_map == device_map device_map = 1 - ar = 
AutoRoundMLLM(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map) + ar = AutoRound(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map) assert ar.device == "cuda:1" assert ar.device_map == device_map device_map = "auto" - ar = AutoRoundMLLM(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map) + ar = AutoRound(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map) assert ar.device == "cuda" assert ar.device_map == device_map device_map = {"model.language_model.layers": 0, "model.visual.blocks": 1} - ar = AutoRoundMLLM(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map) + ar = AutoRound(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map) assert ar.model.model.language_model.layers[0].self_attn.q_proj.tuning_device == "cuda:0" assert ar.model.model.visual.blocks[0].mlp.gate_proj.tuning_device == "cuda:1" diff --git a/test/test_cuda/algorithms/test_alg_ext.py b/test/test_cuda/algorithms/test_alg_ext.py index 17cbe6d90..7aceb84fb 100644 --- a/test/test_cuda/algorithms/test_alg_ext.py +++ b/test/test_cuda/algorithms/test_alg_ext.py @@ -48,9 +48,9 @@ def test_all_support_dtype(self, scheme, tiny_qwen_model_path): def test_2bits(self): model_name = get_model_path("facebook/opt-125m") ar = AutoRound(model=model_name, bits=2, group_size=64, enable_alg_ext=True) - ar.quantize_and_save(self.save_folder) + _, quantized_model_path = ar.quantize_and_save(self.save_folder) model = AutoModelForCausalLM.from_pretrained( - self.save_folder, + quantized_model_path, device_map="auto", ) diff --git a/test/test_cuda/algorithms/test_auto_scheme.py b/test/test_cuda/algorithms/test_auto_scheme.py index 9cde148d7..cd9f85e73 100644 --- a/test/test_cuda/algorithms/test_auto_scheme.py +++ b/test/test_cuda/algorithms/test_auto_scheme.py @@ -264,16 +264,16 @@ def test_auto_scheme_export(self): model_name = get_model_path("facebook/opt-125m") scheme = AutoScheme(avg_bits=3, options=("W2A16", "W4A16", "W8A16", "BF16")) ar = AutoRound(model=model_name, scheme=scheme) - ar.quantize_and_save(output_dir=self.save_dir) - evaluate_accuracy(self.save_dir, threshold=0.25) + _, quantized_model_path = ar.quantize_and_save(output_dir=self.save_dir) + evaluate_accuracy(quantized_model_path, threshold=0.25) @pytest.mark.skip_ci(reason="The evaluation is time-consuming") def test_enable_torch_compile(self): model_name = get_model_path("facebook/opt-125m") scheme = AutoScheme(avg_bits=2, options=("W2A16"), ignore_scale_zp_bits=True) ar = AutoRound(model=model_name, scheme=scheme, enable_torch_compile=True) - ar.quantize_and_save(output_dir=self.save_dir) - evaluate_accuracy(self.save_dir, threshold=0.10) + _, quantized_model_path = ar.quantize_and_save(output_dir=self.save_dir) + evaluate_accuracy(quantized_model_path, threshold=0.10) def test_mixed_bits_get_scoring(self): """Verify that AutoScheme scoring produces accuracy above a known reference threshold for mixed-bit diff --git a/test/test_cuda/backends/test_exllamav2_backend.py b/test/test_cuda/backends/test_exllamav2_backend.py index 4ed812105..db3a15b63 100644 --- a/test/test_cuda/backends/test_exllamav2_backend.py +++ b/test/test_cuda/backends/test_exllamav2_backend.py @@ -41,23 +41,25 @@ def test_gptqmodel_exllmav2_4bits_asym(self, dataloader): model_path, bits=bits, group_size=group_size, sym=sym, iters=1, seqlen=2, dataset=dataloader ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:gptqmodel") + _, 
quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:gptqmodel" + ) quantization_config = AutoRoundConfig(backend="gptqmodel:exllamav2") model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.35, batch_size=16) torch.cuda.empty_cache() model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.35, batch_size=16) torch.cuda.empty_cache() @@ -79,14 +81,16 @@ def test_gptq_exllamav2_4bits_sym(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") ##will convert to gptq model + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round" + ) ##will convert to gptq model quantization_config = AutoRoundConfig(backend="gptq:exllamav2") ## or exllamav2 model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.27, batch_size=16) torch.cuda.empty_cache() @@ -108,14 +112,16 @@ def test_gptq_exllamav2_4bits_sym_group_size(self, group_size): sym=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") ##will convert to gptq model + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round" + ) ##will convert to gptq model quantization_config = AutoRoundConfig(backend="gptq:exllamav2") ## or exllamav2 model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.15, batch_size=64) torch.cuda.empty_cache() @@ -134,15 +140,17 @@ def test_gptqmodel_awq_exllamav2_4bits_asym(self, dataloader): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:auto_awq") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:auto_awq" + ) quantization_config = 
AutoRoundConfig(backend="gptqmodel:awq_exllamav2") # test awq bfloat16 inference model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config ) assert model.dtype == torch.bfloat16, f"Expected model dtype bfloat16, got {model.dtype}" - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) # Inference generation check eval_generated_prompt(model, tokenizer) # Accuracy check @@ -165,14 +173,16 @@ def test_gptqmodel_awq_exllamav2_4bits_sym(self, dataloader): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:auto_awq") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:auto_awq" + ) quantization_config = AutoRoundConfig(backend="gptqmodel:awq_exllamav2") model = AutoModelForCausalLM.from_pretrained( # test awq bfloat16 inference - self.save_dir, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) # Inference generation check eval_generated_prompt(model, tokenizer) # Accuracy check diff --git a/test/test_cuda/backends/test_marlin_backend.py b/test/test_cuda/backends/test_marlin_backend.py index ff00315d4..95b71b02b 100644 --- a/test/test_cuda/backends/test_marlin_backend.py +++ b/test/test_cuda/backends/test_marlin_backend.py @@ -49,14 +49,14 @@ def test_marlin_4bits_sym_with_zp_m_1(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") quantization_config = AutoRoundConfig(backend="marlin") model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.27, batch_size=16) torch.cuda.empty_cache() @@ -80,14 +80,17 @@ def test_marlin_group_size(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") quantization_config = AutoRoundConfig(backend="marlin") model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, + torch_dtype=torch.float16, + device_map="auto", + quantization_config=quantization_config, ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, 
tokenizer, threshold=0.14, batch_size=16) @@ -107,14 +110,17 @@ def test_marlin_group_size(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") quantization_config = AutoRoundConfig(backend="marlin") model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, + torch_dtype=torch.float16, + device_map="auto", + quantization_config=quantization_config, ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.14, batch_size=16) @@ -137,7 +143,7 @@ def test_marlin_group_size(self, dataloader): # # quantization_config = AutoRoundConfig(backend="marlin") # model = AutoModelForCausalLM.from_pretrained( - # self.save_dir, + # quantized_model_path, # torch_dtype=torch.float16, # device_map="auto", # quantization_config=quantization_config @@ -179,14 +185,16 @@ def test_gptqmodel_awq_marlin_4bits_sym(self): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:auto_awq") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:auto_awq" + ) quantization_config = AutoRoundConfig(backend="gptqmodel:awq_marlin") model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype="auto", device_map="cuda:0", quantization_config=quantization_config + quantized_model_path, torch_dtype="auto", device_map="cuda:0", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) # Inference generation check eval_generated_prompt(model, tokenizer) # Accuracy check @@ -211,14 +219,16 @@ def test_gptqmodel_awq_marlin_group_size(self, group_size): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:auto_awq") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:auto_awq" + ) quantization_config = AutoRoundConfig(backend="gptqmodel:awq_marlin") model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype="auto", device_map="cuda:0", quantization_config=quantization_config + quantized_model_path, torch_dtype="auto", device_map="cuda:0", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) # Inference generation check eval_generated_prompt(model, tokenizer) # Accuracy check diff --git a/test/test_cuda/backends/test_torch_backend.py b/test/test_cuda/backends/test_torch_backend.py index dfd2c85eb..71e743f14 100644 --- a/test/test_cuda/backends/test_torch_backend.py +++ b/test/test_cuda/backends/test_torch_backend.py @@ -49,14 +49,16 @@ def test_torch_4bits_asym(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:gptqmodel") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, 
format="auto_round:gptqmodel" + ) quantization_config = AutoRoundConfig(backend="torch") model = AutoModelForCausalLM.from_pretrained( quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.35, batch_size=16) torch.cuda.empty_cache() @@ -79,13 +81,15 @@ def test_torch_4bits_sym(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") ##will convert to gptq model + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round" + ) ##will convert to gptq model quantization_config = AutoRoundConfig(backend="torch") model = AutoModelForCausalLM.from_pretrained( quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.28, batch_size=16) torch.cuda.empty_cache() @@ -130,7 +134,9 @@ def test_autoround_3bit_sym_torch_format(self, tiny_opt_model_path, dataloader): autoround.quantize() quantized_model_path = self.save_dir - autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="auto_round") + _, quantized_model_path = autoround.save_quantized( + output_dir=quantized_model_path, inplace=False, format="auto_round", return_folders=True + ) device = "auto" ##cpu, hpu, cuda from transformers import AutoRoundConfig @@ -160,14 +166,16 @@ def test_gptqmodel_awq_torch_4bits_group_size_16(self, dataloader): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:auto_awq") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:auto_awq" + ) quantization_config = AutoRoundConfig(backend="gptqmodel:awq_torch") model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) # Inference generation check output = model_infer(model, tokenizer) assert isinstance(output, str) and len(output.strip()) > 0, "Model failed to generate non-empty output" diff --git a/test/test_cuda/backends/test_triton_backend.py b/test/test_cuda/backends/test_triton_backend.py index 6675d1620..fa1c1f152 100644 --- a/test/test_cuda/backends/test_triton_backend.py +++ b/test/test_cuda/backends/test_triton_backend.py @@ -35,14 +35,14 @@ def test_tritonv2_2bits_asym(self): bits, group_size, sym = 2, 32, False autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path) + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path) quantization_config = AutoRoundConfig(backend="tritonv2") model = AutoModelForCausalLM.from_pretrained( - self.save_folder, 
torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.19, batch_size=16) torch.cuda.empty_cache() @@ -65,14 +65,16 @@ def test_tritonv2_4bits_asym(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:gptqmodel") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:gptqmodel" + ) quantization_config = AutoRoundConfig(backend="tritonv2") model = AutoModelForCausalLM.from_pretrained( - self.save_folder, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.34, batch_size=16) torch.cuda.empty_cache() @@ -95,23 +97,23 @@ def test_tritonv2_4bits_sym(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path) + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path) quantization_config = AutoRoundConfig(backend="tritonv2") model = AutoModelForCausalLM.from_pretrained( - self.save_folder, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.26, batch_size=16) torch.cuda.empty_cache() model = AutoModelForCausalLM.from_pretrained( - self.save_folder, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.26, batch_size=16) torch.cuda.empty_cache() @@ -125,23 +127,23 @@ def test_tritonv2_8bits_sym(self): bits, group_size, sym = 4, 256, True autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, nsamples=1, iters=1) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path) + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path) quantization_config = AutoRoundConfig(backend="tritonv2") model = AutoModelForCausalLM.from_pretrained( - self.save_folder, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = 
AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.27, batch_size=16) torch.cuda.empty_cache() model = AutoModelForCausalLM.from_pretrained( - self.save_folder, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.27, batch_size=16) torch.cuda.empty_cache() @@ -161,23 +163,23 @@ def test_tritonv2_2bits_sym(self): sym=sym, ) quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path) + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path) quantization_config = AutoRoundConfig(backend="tritonv2") model = AutoModelForCausalLM.from_pretrained( - self.save_folder, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.18, batch_size=16) torch.cuda.empty_cache() model = AutoModelForCausalLM.from_pretrained( - self.save_folder, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config + quantized_model_path, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config ) - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.18, batch_size=16) torch.cuda.empty_cache() diff --git a/test/test_cuda/export/test_auto_awq_format.py b/test/test_cuda/export/test_auto_awq_format.py index 394b6ab9f..2ee2b9de4 100644 --- a/test/test_cuda/export/test_auto_awq_format.py +++ b/test/test_cuda/export/test_auto_awq_format.py @@ -70,7 +70,7 @@ def test_autoawq_format_fp_qsave_layers(self): layer_config=layer_config, ) quantized_model_path = os.path.join(self.save_dir, "test_export") - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_awq") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_awq") # test loading with AutoRoundConfig model = AutoModelForCausalLM.from_pretrained( @@ -99,7 +99,7 @@ def test_fallback_regex_for_awq_format(self, tiny_opt_model_path, dataloader): layer_config=layer_config, ) quantized_model_path = "self.save_dir" - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_awq") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_awq") quantization_config = AutoRoundConfig() model = AutoModelForCausalLM.from_pretrained( quantized_model_path, device_map="auto", quantization_config=quantization_config diff --git a/test/test_cuda/export/test_auto_gptq_format.py b/test/test_cuda/export/test_auto_gptq_format.py index 474e058d0..f7a2891cf 100644 --- a/test/test_cuda/export/test_auto_gptq_format.py +++ b/test/test_cuda/export/test_auto_gptq_format.py @@ -72,7 +72,7 @@ def 
test_autogptq_format_qsave_ignore_layers(self): layer_config=layer_config, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) diff --git a/test/test_cuda/export/test_auto_round_format.py b/test/test_cuda/export/test_auto_round_format.py index ba9862712..863dfc9e9 100644 --- a/test/test_cuda/export/test_auto_round_format.py +++ b/test/test_cuda/export/test_auto_round_format.py @@ -53,7 +53,7 @@ def test_autoround_format(self, tiny_opt_model_path, bits, group_size, is_sym): ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") # Verify loading model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cuda:0", trust_remote_code=True) @@ -74,7 +74,7 @@ def test_mixed_precision(self): bits, group_size, sym = 4, 128, True autoround = AutoRound(model_name, bits=bits, group_size=group_size, sym=sym, layer_config=layer_config) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") eval_generated_prompt(quantized_model_path) evaluate_accuracy(quantized_model_path, threshold=0.32, batch_size=16) @@ -92,23 +92,31 @@ def test_awq_backend(self): sym=sym, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:auto_awq") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:auto_awq" + ) quantization_config = AutoRoundConfig(backend="auto") model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.float16, device_map="cuda:0", quantization_config=quantization_config + quantized_model_path, + torch_dtype=torch.float16, + device_map="cuda:0", + quantization_config=quantization_config, ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) eval_generated_prompt(model, tokenizer) evaluate_accuracy(model, tokenizer, threshold=0.18, batch_size=16) torch.cuda.empty_cache() model = AutoModelForCausalLM.from_pretrained( - self.save_dir, torch_dtype=torch.bfloat16, device_map="cuda:0", quantization_config=quantization_config + quantized_model_path, + torch_dtype=torch.bfloat16, + device_map="cuda:0", + quantization_config=quantization_config, ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) eval_generated_prompt(model, tokenizer) @pytest.mark.skip_ci(reason="Time-consuming; Accuracy evaluation") @@ -138,7 +146,7 @@ def test_autoround_gptq_sym_format(self, tiny_opt_model_path, dataloader): ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path) + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path) from transformers import AutoRoundConfig @@ -188,7 +196,9 @@ def test_autoround_awq_sym_format(self, tiny_opt_model_path, dataloader): 
) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:auto_awq") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:auto_awq" + ) model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) diff --git a/test/test_cuda/export/test_gguf_format.py b/test/test_cuda/export/test_gguf_format.py index 1e413970f..fb4069048 100644 --- a/test/test_cuda/export/test_gguf_format.py +++ b/test/test_cuda/export/test_gguf_format.py @@ -132,7 +132,7 @@ def test_special_model(self): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") file_name = os.listdir(quantized_model_path)[0] file_size = os.path.getsize(os.path.join(quantized_model_path, file_name)) / 1024**2 assert abs(file_size - 307) < 5.0 @@ -169,11 +169,11 @@ def test_vlm_gguf(self): quant_nontext_module=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_k_m") - assert "mmproj-model.gguf" in os.listdir(self.save_dir) - for file in os.listdir(self.save_dir): - print(f"{file}: {os.path.getsize(os.path.join(self.save_dir, file)) / 1024**2} MB") - file_size = os.path.getsize(os.path.join(self.save_dir, file)) / 1024**2 + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_k_m") + assert "mmproj-model.gguf" in os.listdir(quantized_model_path) + for file in os.listdir(quantized_model_path): + print(f"{file}: {os.path.getsize(os.path.join(quantized_model_path, file)) / 1024**2} MB") + file_size = os.path.getsize(os.path.join(quantized_model_path, file)) / 1024**2 if "mmproj-model.gguf" in file: assert abs(file_size - 75) < 5.0 else: @@ -198,7 +198,7 @@ def test_q2k_mixed(self): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q2_k_mixed") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q2_k_mixed") gguf_file = os.listdir(quantized_model_path)[0] file_size = os.path.getsize(os.path.join(quantized_model_path, gguf_file)) / 1024**2 assert abs(file_size - 1236) < 5.0 @@ -229,7 +229,7 @@ def test_q2_k_s_ffn_down_q4k(self): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q2_k_s") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q2_k_s") gguf_file = os.listdir(quantized_model_path)[0] gguf_model = GGUFReader(os.path.join(quantized_model_path, gguf_file)) ffn_down_type = None @@ -258,5 +258,7 @@ def test_gguf_baseline(self): disable_opt_rtn=True, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, inplace=False, format="fake") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, inplace=False, format="fake" + ) eval_generated_prompt(quantized_model_path) diff --git a/test/test_cuda/export/test_llmc_format.py b/test/test_cuda/export/test_llmc_format.py index 232bbb94b..e15223f83 100644 --- a/test/test_cuda/export/test_llmc_format.py +++ 
b/test/test_cuda/export/test_llmc_format.py @@ -7,8 +7,11 @@ from auto_round import AutoRound from auto_round import schemes as ar_schemes +from ...envs import is_compressed_tensors_available from ...helpers import eval_generated_prompt, get_model_path, is_cuda_support_fp8 +pytestmark = pytest.mark.skipif(not is_compressed_tensors_available(), reason="test requires compressed-tensors") + class TestAutoRound: @@ -41,7 +44,7 @@ def test_fp8input_mxfp4_llmcompressor_format(self, dataloader, tiny_fp8_qwen_mod dataset=dataloader, ) print(ar.model) - compressed_model, _ = ar.quantize_and_save(output_dir=self.save_dir, format="llm_compressor") + compressed_model, quantized_model_path = ar.quantize_and_save(output_dir=self.save_dir, format="llm_compressor") tmp_layer = compressed_model.model.layers[1].self_attn.q_proj assert ( hasattr(tmp_layer, "weight_scale") @@ -50,7 +53,7 @@ def test_fp8input_mxfp4_llmcompressor_format(self, dataloader, tiny_fp8_qwen_mod and tmp_layer.weight_scale.shape[0] == 2048 ), "Illegal MXFP4 packing name or data_type or shape" quantization_config = transformers.AutoConfig.from_pretrained( - self.save_dir, trust_remote_code=True + quantized_model_path, trust_remote_code=True ).quantization_config assert ( quantization_config["format"] == "mxfp4-pack-quantized" @@ -68,7 +71,9 @@ def test_nvfp4_llmcompressor_format(self, tiny_opt_model_path, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor") + compressed_model, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="llm_compressor" + ) tmp_layer = compressed_model.model.decoder.layers[1].self_attn.q_proj assert ( hasattr(tmp_layer, "weight_scale") @@ -98,7 +103,9 @@ def test_fp8_block_llm_compressor_format(self, tiny_qwen_model_path, dataloader) disable_opt_rtn=True, ) quantized_model_path = self.save_dir - compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor") + compressed_model, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="llm_compressor" + ) tmp_layer = compressed_model.model.layers[1].self_attn.q_proj assert hasattr(tmp_layer, "weight_scale") assert tmp_layer.weight.dtype is torch.float8_e4m3fn diff --git a/test/test_cuda/integrations/test_sglang.py b/test/test_cuda/integrations/test_sglang.py index 2bee14bd5..196fa2efa 100644 --- a/test/test_cuda/integrations/test_sglang.py +++ b/test/test_cuda/integrations/test_sglang.py @@ -10,11 +10,11 @@ from auto_round import AutoRound -from ...helpers import get_model_path, opt_name_or_path +from ...helpers import get_model_path, qwen_name_or_path class TestAutoRound: - model_name = opt_name_or_path + model_name = qwen_name_or_path @pytest.fixture(autouse=True) def _save_dir(self, tmp_path): @@ -35,6 +35,20 @@ def setup_and_teardown_class(self): shutil.rmtree("runs", ignore_errors=True) def _run_sglang_inference(self, model_path: Path): + # SM 12.x (Blackwell) GPUs require CUDA >= 12.9 for sglang's gptq_marlin_repack JIT kernel. + # Skip inference when the environment is known to be incompatible. 
+ if torch.cuda.is_available(): + try: + major, minor = torch.cuda.get_device_capability() + if major >= 12: + cuda_ver = tuple(int(x) for x in (torch.version.cuda or "0.0").split(".")[:2]) + if cuda_ver < (12, 9): + pytest.skip( + f"SM {major}.{minor} GPU requires CUDA >= 12.9 for sglang GPTQ JIT kernels " + f"(installed: CUDA {torch.version.cuda})" + ) + except Exception: + pass llm = sgl.Engine( model_path=str(model_path), mem_fraction_static=0.5, disable_piecewise_cuda_graph=True, cuda_graph_bs=[1] ) @@ -58,13 +72,13 @@ def test_ar_format_sglang(self, dataloader): dataset=dataloader, ) - autoround.quantize_and_save( + _, quantized_model_path = autoround.quantize_and_save( output_dir=self.save_dir, inplace=True, format="auto_round", ) - generated_text = self._run_sglang_inference(self.save_dir) + generated_text = self._run_sglang_inference(quantized_model_path) print(generated_text) assert "!!!" not in generated_text @@ -73,7 +87,7 @@ def test_mixed_ar_format_sglang(self, dataloader): layer_config = { "self_attn": {"bits": 8}, "lm_head": {"bits": 16}, - "fc1": {"bits": 16, "act_bits": 16}, + "mlp": {"bits": 16, "act_bits": 16}, } autoround = AutoRound( @@ -85,22 +99,22 @@ def test_mixed_ar_format_sglang(self, dataloader): layer_config=layer_config, ) - autoround.quantize_and_save( + _, quantized_model_path = autoround.quantize_and_save( output_dir=self.save_dir, inplace=True, format="auto_round", ) - config_file = Path(self.save_dir) / "config.json" + config_file = Path(quantized_model_path) / "config.json" with open(config_file, "r", encoding="utf-8") as f: config = json.load(f) quant_config = config.get("quantization_config", {}) extra_config = quant_config.get("extra_config", {}) # check extra_config only saved attributes differing from Scheme values - assert "act_bits" not in extra_config[".*fc1.*"].keys() - assert "group_size" not in extra_config[".*fc1.*"].keys() - assert "bits" in extra_config[".*fc1.*"].keys() and extra_config[".*fc1.*"]["bits"] == 16 + assert "act_bits" not in extra_config[".*mlp.*"].keys() + assert "group_size" not in extra_config[".*mlp.*"].keys() + assert "bits" in extra_config[".*mlp.*"].keys() and extra_config[".*mlp.*"]["bits"] == 16 assert "bits" in extra_config[".*self_attn.*"].keys() and extra_config[".*self_attn.*"]["bits"] == 8 - generated_text = self._run_sglang_inference(self.save_dir) + generated_text = self._run_sglang_inference(quantized_model_path) print(generated_text) assert "!!!" not in generated_text @@ -117,13 +131,13 @@ def test_awq_format_sglang(self, dataloader): dataset=dataloader, ) - autoround.quantize_and_save( + _, quantized_model_path = autoround.quantize_and_save( output_dir=self.save_dir, inplace=True, format="auto_round:auto_awq", ) - generated_text = self._run_sglang_inference(self.save_dir) + generated_text = self._run_sglang_inference(quantized_model_path) print(generated_text) assert "!!!" not in generated_text diff --git a/test/test_cuda/integrations/test_vllm.py b/test/test_cuda/integrations/test_vllm.py index 3cef719e5..e9faa3e14 100644 --- a/test/test_cuda/integrations/test_vllm.py +++ b/test/test_cuda/integrations/test_vllm.py @@ -21,6 +21,31 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +def _is_sm12_with_old_cuda() -> bool: + """Return True when the GPU is SM 12.x (Blackwell) and CUDA < 12.9. + + vLLM's gptq_marlin JIT kernels require CUDA >= 12.9 on SM 12.x devices. 
+ """ + try: + import torch + + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + if major < 12: + return False + cuda_ver = tuple(int(x) for x in (torch.version.cuda or "0.0").split(".")[:2]) + return cuda_ver < (12, 9) + except Exception: + return False + + +pytestmark = pytest.mark.skipif( + _is_sm12_with_old_cuda(), + reason="SM 12.x (Blackwell) GPU requires CUDA >= 12.9 for vLLM GPTQ marlin JIT kernels", +) + MODELS = [ "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc", ##auto_round:auto_gptq "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound", ##auto_round:auto_awq @@ -100,7 +125,7 @@ def test_mixed_llmcompressor_format_vllm(tiny_opt_model_path, dataloader, tmp_pa layer_config=layer_config, ) quantized_model_path = str(tmp_path / "saved") - autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor") + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor") # verify loading. llm = LLM( @@ -165,9 +190,9 @@ def test_auto_round_awq_format_vllm(): iters=1, seqlen=2, ) - autoround.quantize_and_save(output_dir=save_dir, format="auto_round:auto_awq") + _, quantized_model_path = autoround.quantize_and_save(output_dir=save_dir, format="auto_round:auto_awq") sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32) - llm = LLM(model=save_dir, trust_remote_code=True, tensor_parallel_size=1, gpu_memory_utilization=0.7) + llm = LLM(model=quantized_model_path, trust_remote_code=True, tensor_parallel_size=1, gpu_memory_utilization=0.7) outputs = llm.generate(["The capital of France is"], sampling_params) generated_text = outputs[0].outputs[0].text print(generated_text) diff --git a/test/test_cuda/models/test_fp8_model.py b/test/test_cuda/models/test_fp8_model.py index e8343b0b2..0bbbba1f8 100644 --- a/test/test_cuda/models/test_fp8_model.py +++ b/test/test_cuda/models/test_fp8_model.py @@ -32,14 +32,14 @@ def setup_and_teardown_class(self): def test_small_model_rtn_generation(self, mock_fp8_capable_device, tiny_fp8_qwen_model_path): ar = AutoRound(tiny_fp8_qwen_model_path, iters=0, disable_opt_rtn=True) - ar.quantize_and_save(output_dir=self.save_dir) - model = AutoModelForCausalLM.from_pretrained(self.save_dir, torch_dtype="auto", trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + _, quantized_model_path = ar.quantize_and_save(output_dir=self.save_dir) + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, torch_dtype="auto", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) generate_prompt(model, tokenizer) def test_gguf_imatrix(self, mock_fp8_capable_device, tiny_fp8_qwen_model_path): ar = AutoRound(tiny_fp8_qwen_model_path, iters=0) - ar.quantize_and_save(format="gguf:q2_k_s", output_dir=self.save_dir) + _, quantized_model_path = ar.quantize_and_save(format="gguf:q2_k_s", output_dir=self.save_dir) # from llama_cpp import Llama # # gguf_file = os.listdir("saved/Qwen3-0.6B-FP8/-gguf")[0] @@ -47,8 +47,8 @@ def test_gguf_imatrix(self, mock_fp8_capable_device, tiny_fp8_qwen_model_path): # output = llm("There is a girl who likes adventure,", max_tokens=32) # print(output) # shutil.rmtree("./saved", ignore_errors=True) - # model = AutoModelForCausalLM.from_pretrained(self.save_dir, torch_dtype="auto", trust_remote_code=True) - # tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + # model = AutoModelForCausalLM.from_pretrained(quantized_model_path, torch_dtype="auto", 
trust_remote_code=True) + # tokenizer = AutoTokenizer.from_pretrained(quantized_model_path ) # text = "There is a girl who likes adventure," # inputs = tokenizer(text, return_tensors="pt").to(model.device) # print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) @@ -58,21 +58,21 @@ def test_small_model_rtn(self, mock_fp8_capable_device): model_name = get_model_path("Qwen/Qwen3-0.6B-FP8") ar = AutoRound(model=model_name, iters=0) _, folder = ar.quantize_and_save(output_dir=self.save_dir) - evaluate_accuracy(self.save_dir, threshold=0.25) + evaluate_accuracy(folder, threshold=0.25) @pytest.mark.skip_ci(reason="Triton issue; time-consuming") def test_small_model_iters1(self, mock_fp8_capable_device): model_name = get_model_path("Qwen/Qwen3-0.6B-FP8") ar = AutoRound(model=model_name, iters=1) _, folder = ar.quantize_and_save(output_dir=self.save_dir) - evaluate_accuracy(self.save_dir, threshold=0.25) + evaluate_accuracy(folder, threshold=0.25) @pytest.mark.skip_ci(reason="Triton issue; time-consuming") def test_medium_model_rtn(self, mock_fp8_capable_device): model_name = get_model_path("Qwen/Qwen3-0.6B-FP8") ar = AutoRound(model=model_name, iters=0) _, folder = ar.quantize_and_save(output_dir=self.save_dir) - evaluate_accuracy(self.save_dir, threshold=0.33) + evaluate_accuracy(folder, threshold=0.33) @pytest.mark.skip_ci(reason="Triton issue; time-consuming") def test_medium_model_rtn_with_lm_head(self, mock_fp8_capable_device): @@ -80,17 +80,17 @@ def test_medium_model_rtn_with_lm_head(self, mock_fp8_capable_device): layer_config = {"lm_head": {"bits": 4}} ar = AutoRound(model=model_name, iters=0, layer_config=layer_config) _, folder = ar.quantize_and_save(output_dir=self.save_dir) - evaluate_accuracy(self.save_dir, threshold=0.33) + evaluate_accuracy(folder, threshold=0.33) def test_fp8_model_gguf_q4(self, mock_fp8_capable_device, tiny_fp8_qwen_model_path): from llama_cpp import Llama ar = AutoRound(tiny_fp8_qwen_model_path, iters=0, disable_opt_rtn=True) - ar.quantize_and_save(output_dir=self.save_dir, format="gguf:q4_0") - for file in os.listdir(self.save_dir): + _, quantized_model_path = ar.quantize_and_save(output_dir=self.save_dir, format="gguf:q4_0") + for file in os.listdir(quantized_model_path): if file.endswith(".gguf"): gguf_file = file - llm = Llama(f"{self.save_dir}/{gguf_file}", n_gpu_layers=-1) + llm = Llama(f"{quantized_model_path}/{gguf_file}", n_gpu_layers=-1) output = llm("There is a girl who likes adventure,", max_tokens=32) print(output) @@ -99,11 +99,11 @@ def test_fp8_model_gguf_q3(self, mock_fp8_capable_device, tiny_fp8_qwen_model_pa from llama_cpp import Llama ar = AutoRound(tiny_fp8_qwen_model_path, iters=1) - ar.quantize_and_save(output_dir=self.save_dir, format="gguf:q3_k_s") - for file in os.listdir(self.save_dir): + _, quantized_model_path = ar.quantize_and_save(output_dir=self.save_dir, format="gguf:q3_k_s") + for file in os.listdir(quantized_model_path): if file.endswith(".gguf"): gguf_file = file - llm = Llama(f"{self.save_dir}/{gguf_file}", n_gpu_layers=-1) + llm = Llama(f"{quantized_model_path}/{gguf_file}", n_gpu_layers=-1) output = llm("There is a girl who likes adventure,", max_tokens=32) print(output) @@ -113,8 +113,8 @@ def test_diff_datatype(self, scheme, tiny_fp8_qwen_model_path, mock_fp8_capable_ model_name = tiny_fp8_qwen_model_path print(f"Testing scheme: {scheme}") ar = AutoRound(model_name, iters=0, scheme=scheme, disable_opt_rtn=True, nsamples=2) - ar.quantize_and_save(output_dir=self.save_dir) - model = 
AutoModelForCausalLM.from_pretrained(self.save_dir, torch_dtype="auto", trust_remote_code=True) + _, quantized_model_path = ar.quantize_and_save(output_dir=self.save_dir) + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, torch_dtype="auto", trust_remote_code=True) assert model is not None, f"Failed to load model for scheme {scheme}" @@ -128,9 +128,9 @@ def test_qwen3_fp8_moe_mxfp(tiny_fp8_qwen_moe_model_path, mock_fp8_capable_devic iters=0, low_cpu_mem_usage=False, ) - quantized_model, _ = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) + quantized_model, quantized_model_path = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) assert quantized_model is not None, "Quantized model should not be None." - loaded_model = AutoModelForCausalLM.from_pretrained(output_dir) + loaded_model = AutoModelForCausalLM.from_pretrained(quantized_model_path) for n, m in quantized_model.named_modules(): if m.__class__.__name__ == "QuantLinear": loaded_m = loaded_model.get_submodule(n) diff --git a/test/test_cuda/models/test_mllm.py b/test/test_cuda/models/test_mllm.py index 52d09175c..ba79f77f8 100644 --- a/test/test_cuda/models/test_mllm.py +++ b/test/test_cuda/models/test_mllm.py @@ -8,7 +8,7 @@ from PIL import Image from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration -from auto_round import AutoRoundMLLM +from auto_round import AutoRound from auto_round.utils import get_block_names from ...envs import require_gptqmodel, require_optimum, require_vlm_env @@ -16,6 +16,7 @@ class VisionDataLoader: + def __init__(self): self.batch_size = 1 @@ -37,7 +38,8 @@ def __iter__(self): @pytest.mark.skip_ci(reason="Only tiny model is suggested") -class TestAutoRoundMLLM: +class TestAutoRound: + @pytest.fixture(autouse=True) def _save_dir(self, tmp_path): self.save_dir = str(tmp_path / "saved") @@ -112,13 +114,13 @@ def qwen_inference(self, quantized_model_dir): @require_gptqmodel @require_optimum def test_vlm_tune(self): - from auto_round import AutoRoundMLLM + from auto_round import AutoRound ## load the model model_name = get_model_path("Qwen/Qwen2-VL-2B-Instruct") ## quantize the model bits, group_size, sym = 4, 128, True - autoround = AutoRoundMLLM(model_name, bits=bits, group_size=group_size, sym=sym, iters=1, nsamples=1) + autoround = AutoRound(model_name, bits=bits, group_size=group_size, sym=sym, iters=1, nsamples=1) autoround.quantize() quantized_model_path = self.save_dir @@ -181,7 +183,7 @@ def test_llama32_vision_early_stop_tracking(self): model_path, trust_remote_code=True, device_map="auto", torch_dtype="auto" ) - autoround = AutoRoundMLLM( + autoround = AutoRound( model=model, tokenizer=tokenizer, processor=processor, diff --git a/test/test_cuda/models/test_moe_model.py b/test/test_cuda/models/test_moe_model.py index 95c61bd55..8ba5e871d 100644 --- a/test/test_cuda/models/test_moe_model.py +++ b/test/test_cuda/models/test_moe_model.py @@ -22,10 +22,10 @@ def test_qwen3_5_moe(tiny_qwen35_moe_model_path): seqlen=32, iters=1, ) - quantized_model, _ = ar.quantize_and_save(format="auto_round", output_dir=output_dir) + quantized_model, quantized_model_path = ar.quantize_and_save(format="auto_round", output_dir=output_dir) assert quantized_model is not None, "Quantized model should not be None." 
- loaded_model = Qwen3_5MoeForConditionalGeneration.from_pretrained(output_dir) + loaded_model = Qwen3_5MoeForConditionalGeneration.from_pretrained(quantized_model_path) loaded_model.to("cuda") inp = torch.randint(0, 100, (1, 64)).to("cuda") diff --git a/test/test_cuda/models/test_omni_model.py b/test/test_cuda/models/test_omni_model.py index b4ccdb81c..d74a4e529 100644 --- a/test/test_cuda/models/test_omni_model.py +++ b/test/test_cuda/models/test_omni_model.py @@ -73,10 +73,10 @@ def test_quantize_and_reload(self, tiny_qwen2_5_omni_model_path, tmp_path): for extra_file in ["spk_dict.pt"]: src = os.path.join(tiny_qwen2_5_omni_model_path, extra_file) if os.path.exists(src): - shutil.copy2(src, tmp_path) + shutil.copy2(src, save_folder) # Reload - loaded_model = Qwen2_5OmniForConditionalGeneration.from_pretrained(tmp_path, device_map="cuda") + loaded_model = Qwen2_5OmniForConditionalGeneration.from_pretrained(save_folder, device_map="cuda") # Run inference on thinker inp = torch.randint(0, 100, (1, 64)).to("cuda") @@ -111,7 +111,7 @@ def test_quantize_and_reload(self, tiny_qwen3_omni_moe_model_path): assert quantized_model is not None, "Quantized model should not be None" # Reload - loaded_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(self.save_dir, device_map="cuda") + loaded_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(save_folder, device_map="cuda") # Run inference on thinker inp = torch.randint(0, 100, (1, 64)).to("cuda") @@ -134,7 +134,7 @@ def test_quantize_mxfp4(self, tiny_qwen3_omni_moe_model_path): assert quantized_model is not None, "MXFP4 quantized model should not be None" # Reload and inference - loaded_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(self.save_dir, device_map="cuda") + loaded_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(save_folder, device_map="cuda") inp = torch.randint(0, 100, (1, 64)).to("cuda") with torch.inference_mode(): diff --git a/test/test_cuda/quantization/test_asym.py b/test/test_cuda/quantization/test_asym.py index 5145cf13d..97ac64288 100644 --- a/test/test_cuda/quantization/test_asym.py +++ b/test/test_cuda/quantization/test_asym.py @@ -29,15 +29,15 @@ def setup_and_teardown_class(self): def test_asym_group_size_with_tuning(self, group_size, tiny_opt_model_path): bits, sym = 4, False ar = AutoRound(tiny_opt_model_path, bits=bits, group_size=group_size, sym=sym, iters=1, seqlen=2, nsamples=1) - ar.quantize_and_save(format="auto_round", output_dir=self.save_dir) + _, quantized_model_path = ar.quantize_and_save(format="auto_round", output_dir=self.save_dir) model = AutoModelForCausalLM.from_pretrained( - self.save_dir, + quantized_model_path, torch_dtype="auto", device_map="auto", ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) @pytest.mark.skip_ci(reason="Not necessary since it's covered by backend tests") # skip this test in CI @@ -45,15 +45,15 @@ def test_asym_group_size_with_tuning(self, group_size, tiny_opt_model_path): def test_asym_bits_with_tuning(self, bits, tiny_opt_model_path): group_size, sym = 128, False ar = AutoRound(tiny_opt_model_path, bits=bits, group_size=group_size, sym=sym, iters=1, seqlen=2, nsamples=1) - ar.quantize_and_save(format="auto_round", output_dir=self.save_dir) + _, quantized_model_path = ar.quantize_and_save(format="auto_round", output_dir=self.save_dir) model = AutoModelForCausalLM.from_pretrained( - self.save_dir, + quantized_model_path, 
torch_dtype="auto", device_map="auto", ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) @pytest.mark.skip_ci(reason="Not necessary since it's covered by backend tests") # skip this test in CI @@ -61,17 +61,17 @@ def test_asym_bits_with_tuning(self, bits, tiny_opt_model_path): def test_asym_format_with_tuning(self, format, tiny_opt_model_path): bits, group_size, sym = 4, 128, False ar = AutoRound(tiny_opt_model_path, bits=bits, group_size=group_size, sym=sym, iters=1, seqlen=2, nsamples=1) - ar.quantize_and_save(format=format, output_dir=self.save_dir) + _, quantized_model_path = ar.quantize_and_save(format=format, output_dir=self.save_dir) if format == "auto_round:auto_gptq": # Cannot load correctly, skip auto_gptq since it's deprecated. return model = AutoModelForCausalLM.from_pretrained( - self.save_dir, + quantized_model_path, torch_dtype="auto", device_map="auto", ) - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) model_infer(model, tokenizer) diff --git a/test/test_cuda/quantization/test_mxfp_nvfp.py b/test/test_cuda/quantization/test_mxfp_nvfp.py index 3e40e7288..bdc79075e 100644 --- a/test/test_cuda/quantization/test_mxfp_nvfp.py +++ b/test/test_cuda/quantization/test_mxfp_nvfp.py @@ -55,7 +55,7 @@ def test_e2e_quant_and_infer(scheme, tiny_qwen_model_path): # Quantize and save the model to the temporary directory quantized_model_path = f"{temp_dir}/tmp_autoround_{scheme}" - autoround.quantize_and_save(format="auto_round", output_dir=quantized_model_path) + _, quantized_model_path = autoround.quantize_and_save(format="auto_round", output_dir=quantized_model_path) # Perform inference with the quantized model model = AutoModelForCausalLM.from_pretrained( @@ -144,7 +144,9 @@ def test_qwen_moe_quant_infer(self, dataloader): layer_config=layer_config, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, inplace=False, format="auto_round") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, inplace=False, format="auto_round" + ) model = AutoModelForCausalLM.from_pretrained(quantized_model_path, torch_dtype="auto", device_map="auto") tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) from ...helpers import evaluate_accuracy diff --git a/test/test_cuda/quantization/test_torch_compile.py b/test/test_cuda/quantization/test_torch_compile.py index 620ef75b0..e7efe0bf2 100644 --- a/test/test_cuda/quantization/test_torch_compile.py +++ b/test/test_cuda/quantization/test_torch_compile.py @@ -64,9 +64,9 @@ def test_gguf_q2ks_torch_compile_iters0(self, tiny_qwen_model_path): seqlen=16, enable_torch_compile=True, ) - autoround.quantize_and_save(output_dir=self.save_dir, format="gguf:q2_k_s") + _, quantized_model_path = autoround.quantize_and_save(output_dir=self.save_dir, format="gguf:q2_k_s") - saved_files = [f for f in os.listdir(self.save_dir) if f.endswith(".gguf")] + saved_files = [f for f in os.listdir(quantized_model_path) if f.endswith(".gguf")] assert len(saved_files) > 0, "No GGUF file was generated" shutil.rmtree(self.save_dir, ignore_errors=True) diff --git a/test/test_cuda/transform/test_mxfp4_transform.py b/test/test_cuda/transform/test_mxfp4_transform.py index 16511edad..5df09db9f 100644 --- a/test/test_cuda/transform/test_mxfp4_transform.py +++ b/test/test_cuda/transform/test_mxfp4_transform.py @@ -38,10 
+38,10 @@ def test_transform_mxfp4_quant_infer(self): scheme=scheme, rotation_config="default", ) - compressed_model, _ = ar.quantize_and_save(output_dir=self.save_dir, format="auto_round") + compressed_model, quantized_model_path = ar.quantize_and_save(output_dir=self.save_dir, format="auto_round") - model = AutoModelForCausalLM.from_pretrained(self.save_dir, torch_dtype="auto", device_map="cuda") - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, torch_dtype="auto", device_map="cuda") + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) from ...helpers import generate_prompt generate_prompt(model, tokenizer) @@ -57,10 +57,10 @@ def test_transform_mxfp4_tuning_quant_infer(self): scheme=scheme, rotation_config="default", ) - compressed_model, _ = ar.quantize_and_save(output_dir=self.save_dir, format="auto_round") + compressed_model, quantized_model_path = ar.quantize_and_save(output_dir=self.save_dir, format="auto_round") - model = AutoModelForCausalLM.from_pretrained(self.save_dir, torch_dtype="auto", device_map="cuda") - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, torch_dtype="auto", device_map="cuda") + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) from ...helpers import generate_prompt generate_prompt(model, tokenizer) @@ -76,10 +76,10 @@ def test_random_transform_mxfp4_quant_infer(self): scheme=scheme, rotation_config="random_hadamard", ) - compressed_model, _ = ar.quantize_and_save(output_dir=self.save_dir, format="auto_round") + compressed_model, quantized_model_path = ar.quantize_and_save(output_dir=self.save_dir, format="auto_round") - model = AutoModelForCausalLM.from_pretrained(self.save_dir, torch_dtype="auto", device_map="cuda") - tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, torch_dtype="auto", device_map="cuda") + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) from ...helpers import generate_prompt generate_prompt(model, tokenizer) diff --git a/test/test_xpu/test_autoround.py b/test/test_xpu/test_autoround.py index 4c4c63678..d03c71444 100644 --- a/test/test_xpu/test_autoround.py +++ b/test/test_xpu/test_autoround.py @@ -12,6 +12,7 @@ class TestAutoRoundXPU: + @classmethod def setup_class(self): self.device = "xpu" @@ -47,7 +48,7 @@ def test_gptq_format(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path) + _, quantized_model_path = autoround.quantize_and_save(output_dir=quantized_model_path) quantization_config = AutoRoundConfig(backend="auto") model = AutoModelForCausalLM.from_pretrained( @@ -78,7 +79,9 @@ def test_awq_format(self, dataloader): dataset=dataloader, ) quantized_model_path = self.save_dir - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:auto_awq") + _, quantized_model_path = autoround.quantize_and_save( + output_dir=quantized_model_path, format="auto_round:auto_awq" + ) quantization_config = AutoRoundConfig(backend="auto") model = AutoModelForCausalLM.from_pretrained( @@ -109,7 +112,9 @@ def test_scheme(self, scheme, dataloader): dataset=dataloader, ) quantized_model_path = "./saved" - ar.quantize_and_save(output_dir=quantized_model_path, inplace=True, format="auto_round") + _, quantized_model_path = ar.quantize_and_save( + output_dir=quantized_model_path, 
inplace=True, format="auto_round" + ) # test loading if scheme not in ["FPW8A16"]: # FPW8A16 group_size is 0 @@ -140,7 +145,9 @@ def test_vlm_model(self, dataloader): ) quantized_model_path = "./saved" - ar.quantize_and_save(output_dir=quantized_model_path, inplace=True, format="auto_round") + _, quantized_model_path = ar.quantize_and_save( + output_dir=quantized_model_path, inplace=True, format="auto_round" + ) quantization_config = AutoRoundConfig(backend="auto") import requests @@ -211,7 +218,9 @@ def test_quant_lm_head(self, dataloader): dataset=dataloader, ) quantized_model_path = "./saved" - ar.quantize_and_save(output_dir=quantized_model_path, inplace=True, format="auto_round") + _, quantized_model_path = ar.quantize_and_save( + output_dir=quantized_model_path, inplace=True, format="auto_round" + ) quantization_config = AutoRoundConfig(backend="auto") model = AutoModelForCausalLM.from_pretrained(