diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index b2bec2651..b1844490f 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -163,6 +163,8 @@ def __new__(
         """
         local_args = {k: v for k, v in locals().items() if k not in cls.SKIP_ARGS}
 
+        if extra_config is not None:
+            local_args.update({k: v for k, v in extra_config.to_dict().items() if k in local_args and v is not None})
 
         if NEW_ARCH:
             from auto_round.compressors_new.entry import AutoRoundCompatible
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 0503c8235..6eb1cfd47 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -94,6 +94,7 @@
     get_layer_names_in_block,
     get_lm_head_name,
     get_module,
+    get_reverse_checkpoint_conversion_mapping,
     global_state,
     hook_ngram_embeddings_on_cpu,
     htcore,
@@ -107,6 +108,7 @@
     memory_monitor,
     mv_module_from_gpu,
     normalize_no_split_modules,
+    revert_checkpoint_conversion_mapping,
     set_amax_for_all_moe_layers,
     set_module,
     to_device,
@@ -3603,6 +3605,21 @@ def save_quantized(
         serialization_dict["autoround_version"] = __version__
         if "scale_dtype" in serialization_dict.keys():
             serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"])
+
+        # Revert converted names so they match the original checkpoint naming
+        reverse_checkpoint_conversion_mapping = get_reverse_checkpoint_conversion_mapping(self.model)
+
+        if isinstance(serialization_dict["to_quant_block_names"], str):
+            serialization_dict["to_quant_block_names"] = revert_checkpoint_conversion_mapping(
+                serialization_dict["to_quant_block_names"], reverse_checkpoint_conversion_mapping
+            )
+
+        elif isinstance(serialization_dict["to_quant_block_names"], list):
+            for idx in range(len(serialization_dict["to_quant_block_names"])):
+                serialization_dict["to_quant_block_names"][idx] = revert_checkpoint_conversion_mapping(
+                    serialization_dict["to_quant_block_names"][idx], reverse_checkpoint_conversion_mapping
+                )
+
         compressed_model = format.save_quantized(
             save_folder,
             model=self.model,
diff --git a/auto_round/compressors/shard_writer.py b/auto_round/compressors/shard_writer.py
index af3f510a3..77b3b7ea4 100644
--- a/auto_round/compressors/shard_writer.py
+++ b/auto_round/compressors/shard_writer.py
@@ -20,7 +20,12 @@
 from torch.nn import Parameter
 
 from auto_round.logger import logger
-from auto_round.utils import get_lm_head_name, get_module
+from auto_round.utils import (
+    get_lm_head_name,
+    get_module,
+    get_reverse_checkpoint_conversion_mapping,
+    revert_checkpoint_conversion_mapping,
+)
 
 
 class ShardWriter:
@@ -55,6 +60,7 @@ def __init__(self, rounder):
         self.shard_meta = []  # List of {tmp_file: str, params: list}
         self.global_weight_map = {}
         self.shard_counter = 0
+        self.reverse_checkpoint_conversion_mapping = get_reverse_checkpoint_conversion_mapping(self.model)
 
         # Persistent set of all parameter names already flushed to a shard file.
         # Maintained incrementally in _flush_shard to avoid O(N^2) rebuilds in _add_tensor.
@@ -105,6 +111,10 @@ def save_module(self, m: torch.nn.Module, name: str = None):
                 self._add_tensor(param_name, v)
 
     def _add_tensor(self, name: str, tensor: torch.Tensor):
+
+        # transformers handles _checkpoint_conversion_mapping automatically when is_immediate_saving=False
+        name = revert_checkpoint_conversion_mapping(name, self.reverse_checkpoint_conversion_mapping)
+
         if isinstance(tensor, torch.Tensor) and tensor.device.type == "meta":
             self.skipped_meta_tensors.append(name)
             return
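Note: the reversion in _add_tensor is only needed on the immediate-saving path. When is_immediate_saving=False, transformers' save_pretrained already maps state-dict keys back through _checkpoint_conversion_mapping before writing, roughly as in the sketch below (paraphrased; details vary across transformers versions, and the helper name _revert_state_dict_keys is illustrative, not a transformers API). Since ShardWriter writes shard files itself, it has to perform the same reversion.

    import re

    def _revert_state_dict_keys(state_dict, checkpoint_conversion_mapping):
        # Sketch of the save-time key reversal transformers performs internally.
        reverse_key_mapping = {v: k for k, v in checkpoint_conversion_mapping.items()}
        original_state_dict = {}
        for key, value in state_dict.items():
            for pattern, replacement in reverse_key_mapping.items():
                replacement = replacement.lstrip("^")  # drop the anchor kept from the forward regex
                replacement = re.sub(r"\(.*\)", "", replacement)  # drop capture groups
                key, n_replace = re.subn(pattern, replacement, key)
                if n_replace > 0:
                    break  # first matching pattern wins
            original_state_dict[key] = value
        return original_state_dict

With a mapping such as {"^visual": "model.visual"}, this turns "model.visual.blocks.0.attn.qkv.weight" back into "visual.blocks.0.attn.qkv.weight".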
diff --git a/auto_round/compressors_new/base.py b/auto_round/compressors_new/base.py
index 025dfbd3f..e8dc9d135 100644
--- a/auto_round/compressors_new/base.py
+++ b/auto_round/compressors_new/base.py
@@ -52,10 +52,12 @@
     extract_block_names_to_str,
     find_matching_blocks,
     get_block_names,
+    get_reverse_checkpoint_conversion_mapping,
     is_debug_mode,
     is_hpex_available,
     is_quantized_input_module,
     memory_monitor,
+    revert_checkpoint_conversion_mapping,
 )
 from auto_round.utils.device import (
     _force_trim_malloc,
@@ -1170,6 +1172,20 @@ def save_quantized(
         if "scale_dtype" in serialization_dict.keys():
             serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"])
 
+        # Revert converted names so they match the original checkpoint naming
+        reverse_checkpoint_conversion_mapping = get_reverse_checkpoint_conversion_mapping(self.model)
+
+        if isinstance(serialization_dict["to_quant_block_names"], str):
+            serialization_dict["to_quant_block_names"] = revert_checkpoint_conversion_mapping(
+                serialization_dict["to_quant_block_names"], reverse_checkpoint_conversion_mapping
+            )
+
+        elif isinstance(serialization_dict["to_quant_block_names"], list):
+            for idx in range(len(serialization_dict["to_quant_block_names"])):
+                serialization_dict["to_quant_block_names"][idx] = revert_checkpoint_conversion_mapping(
+                    serialization_dict["to_quant_block_names"][idx], reverse_checkpoint_conversion_mapping
+                )
+
         compressed_model = format.save_quantized(
             save_folder,
             model=self.model_context.model,
diff --git a/auto_round/compressors_new/shard_writer.py b/auto_round/compressors_new/shard_writer.py
index dbdd2cc86..4a446ad07 100644
--- a/auto_round/compressors_new/shard_writer.py
+++ b/auto_round/compressors_new/shard_writer.py
@@ -22,7 +22,12 @@
 from auto_round.context.compress import CompressContext
 from auto_round.context.model import ModelContext
 from auto_round.logger import logger
-from auto_round.utils import get_lm_head_name, get_module
+from auto_round.utils import (
+    get_lm_head_name,
+    get_module,
+    get_reverse_checkpoint_conversion_mapping,
+    revert_checkpoint_conversion_mapping,
+)
 
 
 class ShardWriter:
@@ -78,6 +83,7 @@ def __init__(
         self.shard_meta = []  # List of {tmp_file: str, params: list}
         self.global_weight_map = {}
         self.shard_counter = 0
+        self.reverse_checkpoint_conversion_mapping = get_reverse_checkpoint_conversion_mapping(self.model)
 
         # Persistent set of all parameter names already flushed to a shard file.
         # Maintained incrementally in _flush_shard to avoid O(N^2) rebuilds in _add_tensor.
@@ -154,6 +160,10 @@ def save_module(self, m: torch.nn.Module, name: str = None):
                 self._add_tensor(param_name, v)
 
     def _add_tensor(self, name: str, tensor: torch.Tensor):
+
+        # transformers handles _checkpoint_conversion_mapping automatically when is_immediate_saving=False
+        name = revert_checkpoint_conversion_mapping(name, self.reverse_checkpoint_conversion_mapping)
+
         if isinstance(tensor, torch.Tensor) and tensor.device.type == "meta":
             self.skipped_meta_tensors.append(name)
             return
diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py
index 70bb311f7..3f973514d 100644
--- a/auto_round/inference/convert_model.py
+++ b/auto_round/inference/convert_model.py
@@ -35,10 +35,12 @@
 from auto_round.special_model_handler import update_module
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
+    apply_checkpoint_conversion_mapping,
     check_start_with_block_name,
     check_to_quantized,
     find_matching_blocks,
     get_block_names,
+    get_checkpoint_conversion_mapping,
     get_module,
     is_hpex_available,
     is_transformers_version_greater_or_equal_5,
@@ -270,12 +272,15 @@ def get_layer_config(model, quantization_config):
     )
 
     # Determine the quantization block list
+    checkpoint_conversion_mapping = get_checkpoint_conversion_mapping(model)
     quant_block_list = getattr(quantization_config, "quant_block_list", None)
     if quant_block_list is not None:
         # Handle nested list format: [[block1, block2, ...], ...] -> [prefix1, ...]
         if quant_block_list and isinstance(quant_block_list[0], (list, tuple)):
             for i in range(len(quant_block_list)):
-                quant_block_list[i] = os.path.commonprefix(quant_block_list[i]).rstrip(".")
+                quant_block_list[i] = apply_checkpoint_conversion_mapping(
+                    os.path.commonprefix(quant_block_list[i]).rstrip("."), checkpoint_conversion_mapping
+                )
     elif quant_block_list is None:
         to_quant_block_names = getattr(quantization_config, "block_name_to_quantize", None)  # Prioritize this parameter
         if to_quant_block_names is None:
@@ -292,6 +297,10 @@ def get_layer_config(model, quantization_config):
         # Speed up the matching
         for i in range(len(quant_block_list)):
             quant_block_list[i] = os.path.commonprefix(quant_block_list[i]).rstrip(".")
+        for i in range(len(quant_block_list)):
+            quant_block_list[i] = apply_checkpoint_conversion_mapping(
+                quant_block_list[i], checkpoint_conversion_mapping
+            )
 
     # Get layer names that will be quantized
     layer_names = []
diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py
index 3b9e1c234..792f113a6 100644
--- a/auto_round/utils/common.py
+++ b/auto_round/utils/common.py
@@ -1055,3 +1055,68 @@ def infer_bits_by_data_type(data_type: str):
         if str.isdigit(data_type[len(supported_dtype)]):
             return int(data_type[len(supported_dtype)])
     return None
+
+
+def get_checkpoint_conversion_mapping(model):
+    """Get the checkpoint conversion mapping for a given model, if it exists."""
+    checkpoint_conversion_mapping = {}
+
+    # transformers <= 5.3.0 uses _checkpoint_conversion_mapping
+    checkpoint_conversion_mapping.update(getattr(model, "_checkpoint_conversion_mapping", {}))
+
+    # transformers > 5.3.0 uses get_checkpoint_conversion_mapping
+    if hasattr(transformers, "conversion_mapping") and (
+        hasattr(model, "config") and hasattr(model.config, "model_type")
+    ):
+        from transformers.conversion_mapping import (
+            get_checkpoint_conversion_mapping as transformers_get_checkpoint_conversion_mapping,
+        )
+
+        conversion_mappings = transformers_get_checkpoint_conversion_mapping(model.config.model_type)
+        if conversion_mappings is not None:
+            for conversion_mapping in conversion_mappings:
+                for source_pattern in conversion_mapping.source_patterns:
+                    checkpoint_conversion_mapping[source_pattern] = conversion_mapping.target_patterns
+    return checkpoint_conversion_mapping
+
+
+def get_reverse_checkpoint_conversion_mapping(model):
+    """Get the reverse checkpoint conversion mapping for a given model, if it exists."""
+    reverse_checkpoint_conversion_mapping = {
+        v: k for k, v in getattr(model, "_checkpoint_conversion_mapping", {}).items()
+    }
+
+    if hasattr(model, "_weight_conversions"):
+        weight_conversions = model._weight_conversions
+        for weight_conversion in weight_conversions:
+            reverse_conversion_mapping = weight_conversion.reverse_transform()
+            for source_pattern in reverse_conversion_mapping.source_patterns:
+                reverse_checkpoint_conversion_mapping[source_pattern] = reverse_conversion_mapping.target_patterns
+
+    return reverse_checkpoint_conversion_mapping
+
+
+def revert_checkpoint_conversion_mapping(name: str, key_mapping: dict[str, str]) -> str:
+    for source_pattern, target_patterns in key_mapping.items():
+        if isinstance(target_patterns, str):
+            target_patterns = [target_patterns]
+        for target_pattern in target_patterns:
+            source_pattern = source_pattern.lstrip("^")  # strip off unneeded chars and patterns
+            source_pattern = re.sub(r"\(.*\)", "", source_pattern)
+            name, n_replace = re.subn(source_pattern, target_pattern, name)
+            # Early exit once a pattern has matched
+            if n_replace > 0:
+                return name
+    return name
+
+
+def apply_checkpoint_conversion_mapping(name: str, key_mapping: dict[str, str]) -> str:
+    for source_pattern, target_patterns in key_mapping.items():
+        if isinstance(target_patterns, str):
+            target_patterns = [target_patterns]
+        for target_pattern in target_patterns:
+            name, n_replace = re.subn(source_pattern, target_pattern, name)
+            # Early exit once a pattern has matched
+            if n_replace > 0:
+                return name
+    return name
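For orientation, a usage sketch of the new helpers. The mapping values below are illustrative (in the style of transformers' _checkpoint_conversion_mapping for Qwen2.5-VL-like models); the real patterns depend on the model class and the installed transformers version.

    from auto_round.utils import (
        apply_checkpoint_conversion_mapping,
        revert_checkpoint_conversion_mapping,
    )

    # Forward (load-time) direction: checkpoint name -> in-memory module name.
    mapping = {"^visual": "model.visual"}  # illustrative
    apply_checkpoint_conversion_mapping("visual.blocks", mapping)
    # -> "model.visual.blocks"

    # Reverse (save-time) direction, assuming plain-prefix reverse patterns:
    reverse_mapping = {"model.language_model": "model"}  # illustrative
    revert_checkpoint_conversion_mapping(
        "model.language_model.layers.0.self_attn.q_proj.weight", reverse_mapping
    )
    # -> "model.layers.0.self_attn.q_proj.weight"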
diff --git a/test/test_cpu/models/test_mllm.py b/test/test_cpu/models/test_mllm.py
index 2c0c71bd4..49b33332a 100644
--- a/test/test_cpu/models/test_mllm.py
+++ b/test/test_cpu/models/test_mllm.py
@@ -236,6 +236,7 @@ def test_qwen2_5(self, tiny_qwen_2_5_vl_model_path):
         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
             quantized_model_path, torch_dtype="auto", device_map="auto"
         )
+        assert model.config.quantization_config.block_name_to_quantize == "model.visual.blocks,model.layers"
         image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
         processor = AutoProcessor.from_pretrained(quantized_model_path)
         messages = [
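The added assertion pins down the serialized form: block_name_to_quantize is stored as a comma-joined string of checkpoint-side block prefixes ("model.layers" rather than the renamed in-memory path "model.language_model.layers"); the convert_model.py changes above then map each prefix back onto the in-memory module tree when the quantized model is reloaded.

    block_names = "model.visual.blocks,model.layers".split(",")
    # -> ["model.visual.blocks", "model.layers"]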
diff --git a/test/test_cuda/integrations/test_sglang.py b/test/test_cuda/integrations/test_sglang.py
index 196fa2efa..a99530d2c 100644
--- a/test/test_cuda/integrations/test_sglang.py
+++ b/test/test_cuda/integrations/test_sglang.py
@@ -121,6 +121,40 @@ def test_mixed_ar_format_sglang(self, dataloader):
 
         shutil.rmtree(self.save_dir, ignore_errors=True)
 
+    def test_qwen2_5_vl_loading(self, tiny_qwen_2_5_vl_model_path):
+        from auto_round.utils import mllm_load_model
+
+        layer_config = {
+            "self_attn": {"bits": 8},
+            "lm_head": {"bits": 16},
+            "mlp": {"bits": 16, "act_bits": 16},
+        }
+
+        model, processor, tokenizer, image_processor = mllm_load_model(tiny_qwen_2_5_vl_model_path)
+
+        autoround = AutoRound(
+            model,
+            tokenizer,
+            scheme="W4A16",
+            iters=1,
+            nsamples=1,
+            seqlen=32,
+            processor=processor,
+            image_processor=image_processor,
+            layer_config=layer_config,
+        )
+
+        _, quantized_model_path = autoround.quantize_and_save(
+            output_dir=self.save_dir,
+            inplace=True,
+            format="auto_round",
+        )
+
+        generated_text = self._run_sglang_inference(quantized_model_path)
+        print(generated_text)
+
+        assert "!!!" not in generated_text
+
     @pytest.mark.skip_ci(reason="Cannot work well in CI env")
     def test_awq_format_sglang(self, dataloader):
         autoround = AutoRound(