From 6069e63fd8c19367d2790d94e84214be3503ec5f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 12:45:35 +0900 Subject: [PATCH 01/36] try --- src/transformers/conversion_mapping.py | 3 ++ src/transformers/core_model_loading.py | 43 ++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 2a6dc23ba9d0..4624354b0963 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -96,6 +96,7 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), WeightRenaming(source_patterns=r"^image_newline", target_patterns="model.image_newline"), ], + "clip_vision": [WeightRenaming(source_patterns=r"vision_model\.(.+)", target_patterns=r"\1")], "video_llava": [ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), @@ -668,6 +669,8 @@ def get_model_conversion_mapping( ): conversions = extract_weight_conversions_for_model(submodule) if conversions is not None: + for conversion in conversions: + conversions.restrict_to = submodule.__class__.__name__ weight_conversions.extend(conversions) seen_model_types.add(submodule.config.model_type) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index b43d5354e8ac..2d8c9c16b774 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -583,6 +583,7 @@ def process_source_pattern(source_pattern: str, target_pattern: str) -> str: class WeightTransform: source_patterns: str | list[str] = field(init=True) target_patterns: str | list[str] = field(init=True) + restrict_to: str | None = None compiled_sources: re.Pattern = field(init=False) distributed_operation: TensorParallelLayer | None = None @@ -731,6 +732,24 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]: return collected_tensors + def allow_transform(self, source_or_target_key: str, model: PreTrainedModel) -> bool: + """Check if `source_key`""" + if self.restrict_to is None: + return True + if self.restrict_to == model.__class__.__name__: + return source_or_target_key in model.state_dict().keys() + + restricted_name, current_module = source_or_target_key, model + for name in source_or_target_key.split("."): + restricted_name = restricted_name.removeprefix(name) + restricted_name = restricted_name.removeprefix(".") + current_module = getattr(current_module, name) + if self.restrict_to == current_module.__class__.__name__: + return restricted_name in current_module.state_dict().keys() + + # If we did not find the module class, return True + return True + @dataclass(slots=True) class WeightRenaming(WeightTransform): @@ -1034,8 +1053,9 @@ def rename_source_key( source_key: str, weight_renamings: list[WeightRenaming], weight_converters: list[WeightConverter], - prefix: str | None = None, + model: PreTrainedModel | None = None, meta_state_dict: dict | None = None, + saving: bool = False, ) -> tuple[str, str | None]: """ Rename a source key given all the renaming and weight conversion patterns we have. Also takes care of adding/removing @@ -1045,7 +1065,18 @@ def rename_source_key( # 1. 
apply all renamings in turns (if multiple match, it's the responsibility of the mappings to make sure they # are coherent) for renaming in weight_renamings: - renamed_key, _ = renaming.rename_source_key(renamed_key) + # If we are saving the state_dict, we must check the `restrict_to` modules BEFORE the renaming + if saving: + # Only rename if the restrictions are allowed on that key + if renaming.allow_transform(renamed_key, model): + renamed_key, _ = renaming.rename_source_key(renamed_key) + # If we are loading the state_dict, we must check the `restrict_to` modules AFTER the renaming + else: + source_key = renamed_key + renamed_key, _ = renaming.rename_source_key(renamed_key) + # If after renaming we find that the transform is not allowed due to restrictions, revert + if not renaming.allow_transform(renamed_key, model): + renamed_key = source_key # 2. apply renaming through weight conversions on the key if we have any WeightConverter (here we stop after # the first match, as we assume only 1 converter can match any source key) @@ -1056,6 +1087,7 @@ def rename_source_key( break # 3. check if we need to add or remove prefix if necessary (only during loading, not saving) + prefix = model.base_model_prefix if model is not None else None if prefix is not None and meta_state_dict is not None: if ( renamed_key.startswith(prefix) @@ -1161,7 +1193,6 @@ def convert_and_load_state_dict_in_model( ``` """ - prefix = model.base_model_prefix tp_plan = tp_plan or {} device_map = load_config.device_map or {"": "cpu"} hf_quantizer = load_config.hf_quantizer @@ -1214,11 +1245,11 @@ def convert_and_load_state_dict_in_model( for original_key, tensor in state_dict: # 1. Rename the key according to all renaming pattern and optional weight converter patterns renamed_key, source_pattern = rename_source_key( - original_key, renamings, converters, prefix, meta_model_state_dict + original_key, renamings, converters, model, meta_model_state_dict ) if renamed_key not in meta_model_state_dict and original_key in meta_model_state_dict: # Key should probably not have been renamed but we might need the `prefix` to be added.` - renamed_key, source_pattern = rename_source_key(original_key, [], [], prefix, meta_model_state_dict) + renamed_key, source_pattern = rename_source_key(original_key, [], [], model, meta_model_state_dict) # 2. 
finally, collect the tensor into the proper converter if renamed_key in meta_model_state_dict: @@ -1377,7 +1408,7 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0])) for original_key, tensor in state_dict: # Rename the key according to all renaming pattern and optional weight converter patterns - renamed_key, source_pattern = rename_source_key(original_key, renamings, converters) + renamed_key, source_pattern = rename_source_key(original_key, renamings, converters, model, saving=True) if source_pattern is not None: new_converter = deepcopy(pattern_to_converter[source_pattern]) # each target key gets its own converter instance From b9c885f19ad29a580235b08bf5de742ce4230aa5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 13:06:46 +0900 Subject: [PATCH 02/36] fix --- src/transformers/core_model_loading.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 2d8c9c16b774..9cc6f218e55a 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -701,7 +701,10 @@ def reverse_transform(self) -> WeightTransform: kwargs["operations"] = [op.reverse_op for op in self.operations[::-1]] reverse_transform = self.__class__( - source_patterns=self._original_target_patterns, target_patterns=self._original_source_patterns, **kwargs + source_patterns=self._original_target_patterns, + target_patterns=self._original_source_patterns, + restrict_to=self.restrict_to, + **kwargs, ) return reverse_transform From 9003fd89ed002d8059ed786a70b1901f00b6ba85 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 13:08:09 +0900 Subject: [PATCH 03/36] oupsi typo --- src/transformers/conversion_mapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 4624354b0963..f95d3a582acd 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -96,7 +96,7 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), WeightRenaming(source_patterns=r"^image_newline", target_patterns="model.image_newline"), ], - "clip_vision": [WeightRenaming(source_patterns=r"vision_model\.(.+)", target_patterns=r"\1")], + "clip_vision_model": [WeightRenaming(source_patterns=r"vision_model\.(.+)", target_patterns=r"\1")], "video_llava": [ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), From 2572a25f5f38fd614ec145e524e58b0aa99556a7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 13:09:14 +0900 Subject: [PATCH 04/36] oupsi typo --- src/transformers/conversion_mapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index f95d3a582acd..d06e29ecbd92 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -670,7 +670,7 @@ def get_model_conversion_mapping( conversions = extract_weight_conversions_for_model(submodule) if conversions is not None: for conversion in conversions: - conversions.restrict_to = submodule.__class__.__name__ + 
conversion.restrict_to = submodule.__class__.__name__ weight_conversions.extend(conversions) seen_model_types.add(submodule.config.model_type) From d0f7fb2bc604cf7103460e5c1a4cb53d6a48aacb Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 15:59:21 +0900 Subject: [PATCH 05/36] get rid of dataclasses --- src/transformers/core_model_loading.py | 106 +++++++++++-------------- 1 file changed, 46 insertions(+), 60 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 9cc6f218e55a..892a58a1bc9b 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -25,7 +25,6 @@ from concurrent.futures import Future, ThreadPoolExecutor from contextlib import contextmanager from copy import deepcopy -from dataclasses import dataclass, field from itertools import chain from typing import TYPE_CHECKING, Any @@ -579,42 +578,36 @@ def process_source_pattern(source_pattern: str, target_pattern: str) -> str: return source_pattern -@dataclass(slots=True) class WeightTransform: - source_patterns: str | list[str] = field(init=True) - target_patterns: str | list[str] = field(init=True) - restrict_to: str | None = None - compiled_sources: re.Pattern = field(init=False) - - distributed_operation: TensorParallelLayer | None = None - quantization_operation: ConversionOps | None = None - - collected_tensors: dict[str, list[Future]] = field(default_factory=lambda: defaultdict(list), init=False) - layer_targets: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set), init=False) - - # Those are needed to be able to reverse correctly the transform, as the patterns may be processed - _original_source_patterns: list[str] = field(init=False) - _original_target_patterns: list[str] = field(init=False) + # Restrict the attributes that can be attached + __slots__ = ( + "source_patterns", + "target_patterns", + "compiled_sources", + "distributed_operation", + "quantization_operation", + "collected_tensors", + "layer_targets", + "_original_source_patterns", + "_original_target_patterns", + ) - def __setattr__(self, name, value): - if name in ("source_patterns", "target_patterns"): - # We do not allow to re-set the patterns, as they are linked between each other and changing one - # without the other can mess-up with the capturing groups/compiled sources - if hasattr(self, name): - raise ValueError(f"Cannot assign to field {name}, you should create a new instance") - # Switch str to list - elif isinstance(value, str): - value = [value] - object.__setattr__(self, name, value) + def __init__(self, source_patterns: str | list[str], target_patterns: str | list[str]): + self.source_patterns: list[str] = source_patterns + self.target_patterns: list[str] = target_patterns + # Those are needed to be able to reverse correctly the transform, as the patterns may be processed + self._original_source_patterns = source_patterns.copy() + self._original_target_patterns = target_patterns.copy() - def __post_init__(self): - # Due to how our `_checkpoint_conversion_mapping` mappings are written, we need a few exceptions here - # when instantiating the reverse mapping (i.e. 
the targets become sources, and sources become targets) - # The issues lie in the sources usually, so here we need to check the targets for the reversed mapping + # Init fields that will be used during conversion + self.distributed_operation: TensorParallelLayer | None = None + self.quantization_operation: ConversionOps | None = None + self.collected_tensors: dict[str, list[Future]] = defaultdict(list) + self.layer_targets: dict[str, set[str]] = defaultdict(set) - # We need to copy the exact original patterns to later reverse (before processing may change them) - self._original_source_patterns = self.source_patterns.copy() - self._original_target_patterns = self.target_patterns.copy() + # We need to process a few exceptions here when instantiating the reverse mapping (i.e. the targets become + # sources, and sources become targets). The issues lie in the sources usually, so here we need to check the + # targets for the reversed mapping # Process target_patterns: detect capturing groups and replace with \1 # Store the original capturing group patterns for reverse mapping @@ -658,6 +651,20 @@ def __post_init__(self): branches.append(f"(?P<{group_name}>{pattern})") self.compiled_sources = re.compile("|".join(branches)) + def __repr__(self): + return f"{self.__class__.__name__}(source_patterns={self.source_patterns}, target_patterns={self.target_patterns})" + + def __setattr__(self, name, value): + if name in ("source_patterns", "target_patterns"): + # We do not allow to re-set the patterns, as they are linked between each other and changing one + # without the other can mess-up with the capturing groups/compiled sources + if hasattr(self, name): + raise ValueError(f"Cannot assign to field {name}, you should create a new instance") + # Switch str to list + elif isinstance(value, str): + value = [value] + object.__setattr__(self, name, value) + def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future): self.collected_tensors[source_pattern].append(future) self.layer_targets[target_key].add(source_key) @@ -701,10 +708,7 @@ def reverse_transform(self) -> WeightTransform: kwargs["operations"] = [op.reverse_op for op in self.operations[::-1]] reverse_transform = self.__class__( - source_patterns=self._original_target_patterns, - target_patterns=self._original_source_patterns, - restrict_to=self.restrict_to, - **kwargs, + source_patterns=self._original_target_patterns, target_patterns=self._original_source_patterns, **kwargs ) return reverse_transform @@ -735,26 +739,7 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]: return collected_tensors - def allow_transform(self, source_or_target_key: str, model: PreTrainedModel) -> bool: - """Check if `source_key`""" - if self.restrict_to is None: - return True - if self.restrict_to == model.__class__.__name__: - return source_or_target_key in model.state_dict().keys() - - restricted_name, current_module = source_or_target_key, model - for name in source_or_target_key.split("."): - restricted_name = restricted_name.removeprefix(name) - restricted_name = restricted_name.removeprefix(".") - current_module = getattr(current_module, name) - if self.restrict_to == current_module.__class__.__name__: - return restricted_name in current_module.state_dict().keys() - # If we did not find the module class, return True - return True - - -@dataclass(slots=True) class WeightRenaming(WeightTransform): # Special case of WeightTransform that only renames keys without any conversion. 
@@ -799,12 +784,13 @@ def convert( ) -@dataclass(slots=True) class WeightConverter(WeightTransform): - operations: list[ConversionOps] = field(default_factory=list, repr=False) + __slots__ = WeightTransform.__slots__ + ("operations",) + + def __init__(self, source_patterns: str | list[str], target_patterns: str | list[str]): + super().__init__(source_patterns, target_patterns) + self.operations: list[ConversionOps] = [] - def __post_init__(self): - WeightTransform.__post_init__(self) if bool(len(self.source_patterns) - 1) + bool(len(self.target_patterns) - 1) >= 2: # We allow many-to-many only if we use an internal operation that can handle it if not any(isinstance(op, _INTERNAL_MANY_TO_MANY_CONVERSIONS) for op in self.operations): From 11bc494d859e1ea7a9a7ed0ba2d9e0e912597459 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 17:06:00 +0900 Subject: [PATCH 06/36] try --- src/transformers/conversion_mapping.py | 38 +++++++------ src/transformers/core_model_loading.py | 75 +++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 22 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index d06e29ecbd92..7c6fe275f15c 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -22,9 +22,11 @@ Concatenate, ErnieFuseAndSplitTextVisionExperts, MergeModulelist, + PrefixChange, Transpose, WeightConverter, WeightRenaming, + WeightTransform, ) @@ -96,7 +98,7 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), WeightRenaming(source_patterns=r"^image_newline", target_patterns="model.image_newline"), ], - "clip_vision_model": [WeightRenaming(source_patterns=r"vision_model\.(.+)", target_patterns=r"\1")], + "clip_vision_model": [PrefixChange(prefix_to_remove="vision_model")], "video_llava": [ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), @@ -624,10 +626,16 @@ def register_checkpoint_conversion_mapping( _checkpoint_conversion_mapping_cache[model_type] = mapping -def extract_weight_conversions_for_model(model: PreTrainedModel) -> list[WeightConverter | WeightRenaming] | None: +def extract_weight_conversions_for_model(model: PreTrainedModel, model_prefix: str) -> list[WeightTransform] | None: model_type = getattr(model.config, "model_type", None) if model_type is not None: model_specific_conversions = get_checkpoint_conversion_mapping(model_type) + # In this case, add the prefix to `PrefixChange` instances, in order to know where to add/remove the prefix + if model_prefix != "": + for i, conversion in enumerate(model_specific_conversions): + # In this case, add the prefix + if isinstance(conversion, PrefixChange): + model_specific_conversions[i] = conversion.with_submodel_prefix(model_prefix) return model_specific_conversions return None @@ -637,7 +645,7 @@ def get_model_conversion_mapping( key_mapping: dict[str, str] | None = None, hf_quantizer: HfQuantizer | None = None, add_legacy: bool = True, -) -> list[WeightConverter | WeightRenaming]: +) -> list[WeightTransform]: """ For a given `model`, obtain the weight conversion mapping if any are registered either as a simple renaming `_checkpoint_conversion_mapping` class argument, or in the general WeightConverter mapping. 
@@ -652,26 +660,16 @@ def get_model_conversion_mapping( if key_mapping is not None: weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()] - # Model have several `PreTrainedModel` within with the same model type - # For ex: XForConditionalGeneration -> XModel. We don't want to apply the same - # conversion pattern twice because of that + # Model have several `PreTrainedModel` within with the same model type, for example: XForConditionalGeneration -> XModel + # We don't want to apply the same conversion pattern twice because of that seen_model_types = set() - if (conversions := extract_weight_conversions_for_model(model)) is not None: - weight_conversions.extend(conversions) - seen_model_types.add(model.config.model_type) - # Recurse over submodules and collect all conversions - for submodule in model.modules(): - if ( - submodule is not model - and isinstance(submodule, PreTrainedModel) - and submodule.config.model_type not in seen_model_types - ): - conversions = extract_weight_conversions_for_model(submodule) + for name, submodule in model.named_modules(): + if isinstance(submodule, PreTrainedModel) and submodule.config.model_type not in seen_model_types: + conversions = extract_weight_conversions_for_model(submodule, name) if conversions is not None: - for conversion in conversions: - conversion.restrict_to = submodule.__class__.__name__ - weight_conversions.extend(conversions) + # Important: we want conversions for submodels to appear first!! + weight_conversions = conversions + weight_conversions seen_model_types.add(submodule.config.model_type) if add_legacy: diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 892a58a1bc9b..e6b44d6902f2 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -654,6 +654,13 @@ def __init__(self, source_patterns: str | list[str], target_patterns: str | list def __repr__(self): return f"{self.__class__.__name__}(source_patterns={self.source_patterns}, target_patterns={self.target_patterns})" + def __eq__(self, other: WeightTransform): + return ( + self.__class__ is other.__class__ + and self._original_source_patterns == other._original_source_patterns + and self._original_target_patterns == other._original_target_patterns + ) + def __setattr__(self, name, value): if name in ("source_patterns", "target_patterns"): # We do not allow to re-set the patterns, as they are linked between each other and changing one @@ -743,6 +750,9 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]: class WeightRenaming(WeightTransform): # Special case of WeightTransform that only renames keys without any conversion. 
+ # Needs to be empty, otherwise the class will not be slotted + __slots__ = () + def convert( self, layer_name: str, @@ -777,6 +787,53 @@ def convert( return collected_tensors +class PrefixChange(WeightRenaming): + # Special case of weight renaming, used to easily add/remove a prefix while removing/adding it back + # easily as well during saving + + __slots__ = ( + "prefix_to_add", + "prefix_to_remove", + "_prefix_was_changed", + ) + + def __init__( + self, prefix_to_add: str | None = None, prefix_to_remove: str | None = None, model_prefix: str | None = None + ): + if prefix_to_add is None ^ prefix_to_remove is not None: + raise ValueError("You must provide only one of `prefix_to_add` and `prefix_to_remove`") + + self.prefix_to_add = prefix_to_add + self.prefix_to_remove = prefix_to_remove + model_prefix = "" if model_prefix is None else model_prefix + + if prefix_to_add is not None: + super().__init__( + source_patterns=rf"^{model_prefix}\.(.+)$", target_patterns=rf"{model_prefix}\.{prefix_to_add}\.\1" + ) + else: + super().__init__( + source_patterns=rf"^{model_prefix}\.{prefix_to_remove}\.(.+)$", target_patterns=rf"{model_prefix}\.\1" + ) + + # Flag to signal at runtime if the instance was used, i.e. if the checkpoints matched the added/removed + # prefix. If it ends-up being True, the opposite will be used when saving + self._prefix_was_changed = False + + def rename_source_key(self, source_key: str): + renamed_key, source_pattern_that_matched = super().rename_source_key(source_key) + if renamed_key != source_key: + self._prefix_was_changed = True + + def prefix_was_changed(self): + return self._prefix_was_changed + + def with_submodel_prefix(self, prefix: str) -> PrefixChange: + return PrefixChange( + prefix_to_add=self.prefix_to_add, prefix_to_remove=self.prefix_to_remove, model_prefix=prefix + ) + + # List of classes that are known to be able to use m:n _INTERNAL_MANY_TO_MANY_CONVERSIONS = ( ErnieFuseAndSplitTextVisionExperts, @@ -785,7 +842,7 @@ def convert( class WeightConverter(WeightTransform): - __slots__ = WeightTransform.__slots__ + ("operations",) + __slots__ = ("operations",) def __init__(self, source_patterns: str | list[str], target_patterns: str | list[str]): super().__init__(source_patterns, target_patterns) @@ -1363,7 +1420,21 @@ def convert_and_load_state_dict_in_model( thread_pool.shutdown(wait=False, cancel_futures=True) # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user) - model._weight_conversions = weight_mapping + model_specific_conversions = [] + for conversion in weight_mapping: + # For a prefix change, we need to update at runtime depending on whether the checkpoint already had the correct + # format or not - otherwise, we may end up adding twice the same prefix + if isinstance(conversion, PrefixChange): + used_conversion = next( + used_conversion for used_conversion in param_name_to_load.values() if used_conversion == conversion + ) + # Add the prefix switch to the saved conversion ONLY if it was used at runtime + if used_conversion.prefix_was_changed(): + model_specific_conversions.append(conversion) + else: + model_specific_conversions.append(conversion) + model._weight_conversions = model_specific_conversions + return loading_info, disk_offload_index From 8de7e0fb2a704557b32b8a21f40d65fc3c9ae39a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 17:09:00 +0900 Subject: [PATCH 07/36] oupsi --- src/transformers/core_model_loading.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 
deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index e6b44d6902f2..be69f8ad3522 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -844,9 +844,11 @@ def with_submodel_prefix(self, prefix: str) -> PrefixChange: class WeightConverter(WeightTransform): __slots__ = ("operations",) - def __init__(self, source_patterns: str | list[str], target_patterns: str | list[str]): + def __init__( + self, source_patterns: str | list[str], target_patterns: str | list[str], operations: list[ConversionOps] + ): super().__init__(source_patterns, target_patterns) - self.operations: list[ConversionOps] = [] + self.operations: list[ConversionOps] = operations if bool(len(self.source_patterns) - 1) + bool(len(self.target_patterns) - 1) >= 2: # We allow many-to-many only if we use an internal operation that can handle it From 7f38c23cf9a669e352becd6ab4f2acf4f21c5398 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 17:11:00 +0900 Subject: [PATCH 08/36] revert from before --- src/transformers/core_model_loading.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index be69f8ad3522..5e8baa950b5c 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -1101,9 +1101,8 @@ def rename_source_key( source_key: str, weight_renamings: list[WeightRenaming], weight_converters: list[WeightConverter], - model: PreTrainedModel | None = None, + prefix: str | None = None, meta_state_dict: dict | None = None, - saving: bool = False, ) -> tuple[str, str | None]: """ Rename a source key given all the renaming and weight conversion patterns we have. Also takes care of adding/removing @@ -1113,18 +1112,7 @@ def rename_source_key( # 1. apply all renamings in turns (if multiple match, it's the responsibility of the mappings to make sure they # are coherent) for renaming in weight_renamings: - # If we are saving the state_dict, we must check the `restrict_to` modules BEFORE the renaming - if saving: - # Only rename if the restrictions are allowed on that key - if renaming.allow_transform(renamed_key, model): - renamed_key, _ = renaming.rename_source_key(renamed_key) - # If we are loading the state_dict, we must check the `restrict_to` modules AFTER the renaming - else: - source_key = renamed_key - renamed_key, _ = renaming.rename_source_key(renamed_key) - # If after renaming we find that the transform is not allowed due to restrictions, revert - if not renaming.allow_transform(renamed_key, model): - renamed_key = source_key + renamed_key, _ = renaming.rename_source_key(renamed_key) # 2. apply renaming through weight conversions on the key if we have any WeightConverter (here we stop after # the first match, as we assume only 1 converter can match any source key) @@ -1135,7 +1123,6 @@ def rename_source_key( break # 3. 
check if we need to add or remove prefix if necessary (only during loading, not saving) - prefix = model.base_model_prefix if model is not None else None if prefix is not None and meta_state_dict is not None: if ( renamed_key.startswith(prefix) @@ -1241,6 +1228,7 @@ def convert_and_load_state_dict_in_model( ``` """ + prefix = model.base_model_prefix tp_plan = tp_plan or {} device_map = load_config.device_map or {"": "cpu"} hf_quantizer = load_config.hf_quantizer @@ -1293,11 +1281,11 @@ def convert_and_load_state_dict_in_model( for original_key, tensor in state_dict: # 1. Rename the key according to all renaming pattern and optional weight converter patterns renamed_key, source_pattern = rename_source_key( - original_key, renamings, converters, model, meta_model_state_dict + original_key, renamings, converters, prefix, meta_model_state_dict ) if renamed_key not in meta_model_state_dict and original_key in meta_model_state_dict: # Key should probably not have been renamed but we might need the `prefix` to be added.` - renamed_key, source_pattern = rename_source_key(original_key, [], [], model, meta_model_state_dict) + renamed_key, source_pattern = rename_source_key(original_key, [], [], prefix, meta_model_state_dict) # 2. finally, collect the tensor into the proper converter if renamed_key in meta_model_state_dict: @@ -1470,7 +1458,7 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0])) for original_key, tensor in state_dict: # Rename the key according to all renaming pattern and optional weight converter patterns - renamed_key, source_pattern = rename_source_key(original_key, renamings, converters, model, saving=True) + renamed_key, source_pattern = rename_source_key(original_key, renamings, converters) if source_pattern is not None: new_converter = deepcopy(pattern_to_converter[source_pattern]) # each target key gets its own converter instance From b98194d3336db3a87b79946ce0e5740d0187f5d5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 17:12:26 +0900 Subject: [PATCH 09/36] fix --- src/transformers/core_model_loading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 5e8baa950b5c..1932394719f0 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -596,8 +596,8 @@ def __init__(self, source_patterns: str | list[str], target_patterns: str | list self.source_patterns: list[str] = source_patterns self.target_patterns: list[str] = target_patterns # Those are needed to be able to reverse correctly the transform, as the patterns may be processed - self._original_source_patterns = source_patterns.copy() - self._original_target_patterns = target_patterns.copy() + self._original_source_patterns = self.source_patterns.copy() + self._original_target_patterns = self.target_patterns.copy() # Init fields that will be used during conversion self.distributed_operation: TensorParallelLayer | None = None From 4c05d6e586da8d30c1d8e50b513bbf9740977c84 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 17:16:39 +0900 Subject: [PATCH 10/36] add parenthesis --- src/transformers/core_model_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 1932394719f0..2073edb89eed 100644 --- 
a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -800,7 +800,7 @@ class PrefixChange(WeightRenaming): def __init__( self, prefix_to_add: str | None = None, prefix_to_remove: str | None = None, model_prefix: str | None = None ): - if prefix_to_add is None ^ prefix_to_remove is not None: + if (prefix_to_add is None) ^ (prefix_to_remove is not None): raise ValueError("You must provide only one of `prefix_to_add` and `prefix_to_remove`") self.prefix_to_add = prefix_to_add From b2f8cc8886bbd03a7db54dc6cc97eb1f14fe7e84 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 17:33:40 +0900 Subject: [PATCH 11/36] fix --- src/transformers/conversion_mapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 7c6fe275f15c..8eb33dea95a0 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -631,7 +631,7 @@ def extract_weight_conversions_for_model(model: PreTrainedModel, model_prefix: s if model_type is not None: model_specific_conversions = get_checkpoint_conversion_mapping(model_type) # In this case, add the prefix to `PrefixChange` instances, in order to know where to add/remove the prefix - if model_prefix != "": + if model_specific_conversions is not None and model_prefix != "": for i, conversion in enumerate(model_specific_conversions): # In this case, add the prefix if isinstance(conversion, PrefixChange): From 80d838691f7aa20e99d14b399432934ebd78e2cf Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 19:39:46 +0900 Subject: [PATCH 12/36] fix --- src/transformers/core_model_loading.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 2073edb89eed..6ea6e6bab40a 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -824,6 +824,7 @@ def rename_source_key(self, source_key: str): renamed_key, source_pattern_that_matched = super().rename_source_key(source_key) if renamed_key != source_key: self._prefix_was_changed = True + return renamed_key, source_pattern_that_matched def prefix_was_changed(self): return self._prefix_was_changed @@ -1415,11 +1416,7 @@ def convert_and_load_state_dict_in_model( # For a prefix change, we need to update at runtime depending on whether the checkpoint already had the correct # format or not - otherwise, we may end up adding twice the same prefix if isinstance(conversion, PrefixChange): - used_conversion = next( - used_conversion for used_conversion in param_name_to_load.values() if used_conversion == conversion - ) - # Add the prefix switch to the saved conversion ONLY if it was used at runtime - if used_conversion.prefix_was_changed(): + if conversion.prefix_was_changed(): model_specific_conversions.append(conversion) else: model_specific_conversions.append(conversion) From d3cc313950a81e6e10cb4187592e5f660c4d3444 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 20:01:28 +0900 Subject: [PATCH 13/36] fixes --- src/transformers/conversion_mapping.py | 3 +-- src/transformers/core_model_loading.py | 22 +++++++++++++++++++--- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 8eb33dea95a0..4d6ad4e481a9 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py 
@@ -668,8 +668,7 @@ def get_model_conversion_mapping( if isinstance(submodule, PreTrainedModel) and submodule.config.model_type not in seen_model_types: conversions = extract_weight_conversions_for_model(submodule, name) if conversions is not None: - # Important: we want conversions for submodels to appear first!! - weight_conversions = conversions + weight_conversions + weight_conversions.extend(conversions) seen_model_types.add(submodule.config.model_type) if add_legacy: diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 6ea6e6bab40a..a7ad081a5d95 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -794,6 +794,7 @@ class PrefixChange(WeightRenaming): __slots__ = ( "prefix_to_add", "prefix_to_remove", + "model_prefix", "_prefix_was_changed", ) @@ -805,15 +806,17 @@ def __init__( self.prefix_to_add = prefix_to_add self.prefix_to_remove = prefix_to_remove - model_prefix = "" if model_prefix is None else model_prefix + self.model_prefix = "" if model_prefix is None else model_prefix if prefix_to_add is not None: super().__init__( - source_patterns=rf"^{model_prefix}\.(.+)$", target_patterns=rf"{model_prefix}\.{prefix_to_add}\.\1" + source_patterns=rf"^{self.model_prefix}\.(.+)$", + target_patterns=rf"{self.model_prefix}\.{prefix_to_add}\.\1", ) else: super().__init__( - source_patterns=rf"^{model_prefix}\.{prefix_to_remove}\.(.+)$", target_patterns=rf"{model_prefix}\.\1" + source_patterns=rf"^{self.model_prefix}\.{prefix_to_remove}\.(.+)$", + target_patterns=rf"{self.model_prefix}\.\1", ) # Flag to signal at runtime if the instance was used, i.e. if the checkpoints matched the added/removed @@ -826,6 +829,19 @@ def rename_source_key(self, source_key: str): self._prefix_was_changed = True return renamed_key, source_pattern_that_matched + def reverse_transform(self) -> WeightTransform: + """Reverse the current `WeightTransform` instance, to be able to save with the opposite weight transformations.""" + # TODO: check this and relax when quantizer have `reverse_op` + if self.quantization_operation is not None: + raise ValueError("Cannot reverse the transform with TP or quantization") + + if self.prefix_to_add is not None: + reverse_transform = PrefixChange(prefix_to_remove=self.prefix_to_add, model_prefix=self.model_prefix) + else: + reverse_transform = PrefixChange(prefix_to_add=self.prefix_to_remove, model_prefix=self.model_prefix) + + return reverse_transform + def prefix_was_changed(self): return self._prefix_was_changed From e925ce86c5d8fca4074eb3089fc8632d32d4cc7a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 20:11:48 +0900 Subject: [PATCH 14/36] need to revert the order for saving --- src/transformers/core_model_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index a7ad081a5d95..c5bc84a62de9 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -1436,7 +1436,7 @@ def convert_and_load_state_dict_in_model( model_specific_conversions.append(conversion) else: model_specific_conversions.append(conversion) - model._weight_conversions = model_specific_conversions + model._weight_conversions = model_specific_conversions[::-1] return loading_info, disk_offload_index From e792532495e409e42dbf495dc1d415da89629c7a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 20:13:06 +0900 Subject: [PATCH 15/36] comment --- 
src/transformers/core_model_loading.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index c5bc84a62de9..463c0e8e0a31 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -1436,6 +1436,7 @@ def convert_and_load_state_dict_in_model( model_specific_conversions.append(conversion) else: model_specific_conversions.append(conversion) + # Important: we need to revert the order here, so that potential conversions from submodels are performed first model._weight_conversions = model_specific_conversions[::-1] return loading_info, disk_offload_index From c48218da5f246e04652ea9973bcd4f48f5def47d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 22:14:23 +0900 Subject: [PATCH 16/36] a bit more general --- src/transformers/core_model_loading.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 463c0e8e0a31..e2dfc272ae04 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -846,8 +846,9 @@ def prefix_was_changed(self): return self._prefix_was_changed def with_submodel_prefix(self, prefix: str) -> PrefixChange: + new_prefix = f"{prefix}.{self.model_prefix}" if self.model_prefix != "" else prefix return PrefixChange( - prefix_to_add=self.prefix_to_add, prefix_to_remove=self.prefix_to_remove, model_prefix=prefix + prefix_to_add=self.prefix_to_add, prefix_to_remove=self.prefix_to_remove, model_prefix=new_prefix ) From 01cda19ab5c5f08a9d677ad3221de0cf8680bcd6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 22:23:31 +0900 Subject: [PATCH 17/36] simplify --- src/transformers/core_model_loading.py | 45 ++++++++------------------ 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index e2dfc272ae04..74fa4a40b5b0 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -590,6 +590,7 @@ class WeightTransform: "layer_targets", "_original_source_patterns", "_original_target_patterns", + "_was_used", ) def __init__(self, source_patterns: str | list[str], target_patterns: str | list[str]): @@ -605,6 +606,9 @@ def __init__(self, source_patterns: str | list[str], target_patterns: str | list self.collected_tensors: dict[str, list[Future]] = defaultdict(list) self.layer_targets: dict[str, set[str]] = defaultdict(set) + # Flag to notice if the Transform was used + self._was_used = False + # We need to process a few exceptions here when instantiating the reverse mapping (i.e. the targets become # sources, and sources become targets). 
The issues lie in the sources usually, so here we need to check the # targets for the reversed mapping @@ -654,13 +658,6 @@ def __init__(self, source_patterns: str | list[str], target_patterns: str | list def __repr__(self): return f"{self.__class__.__name__}(source_patterns={self.source_patterns}, target_patterns={self.target_patterns})" - def __eq__(self, other: WeightTransform): - return ( - self.__class__ is other.__class__ - and self._original_source_patterns == other._original_source_patterns - and self._original_target_patterns == other._original_target_patterns - ) - def __setattr__(self, name, value): if name in ("source_patterns", "target_patterns"): # We do not allow to re-set the patterns, as they are linked between each other and changing one @@ -688,6 +685,9 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]: if match_object is None: return source_key, None + # We have a match, so the Transform was used + self._was_used = True + # Find the source that produced the match (it's the first group that matched, as the search stops after first branch match) matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None) source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])] @@ -746,6 +746,10 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]: return collected_tensors + def was_used(self) -> bool: + """Return whether the current Transform matched any weights during loading/saving""" + return self._was_used + class WeightRenaming(WeightTransform): # Special case of WeightTransform that only renames keys without any conversion. @@ -795,7 +799,6 @@ class PrefixChange(WeightRenaming): "prefix_to_add", "prefix_to_remove", "model_prefix", - "_prefix_was_changed", ) def __init__( @@ -819,16 +822,6 @@ def __init__( target_patterns=rf"{self.model_prefix}\.\1", ) - # Flag to signal at runtime if the instance was used, i.e. if the checkpoints matched the added/removed - # prefix. 
If it ends-up being True, the opposite will be used when saving - self._prefix_was_changed = False - - def rename_source_key(self, source_key: str): - renamed_key, source_pattern_that_matched = super().rename_source_key(source_key) - if renamed_key != source_key: - self._prefix_was_changed = True - return renamed_key, source_pattern_that_matched - def reverse_transform(self) -> WeightTransform: """Reverse the current `WeightTransform` instance, to be able to save with the opposite weight transformations.""" # TODO: check this and relax when quantizer have `reverse_op` @@ -842,9 +835,6 @@ def reverse_transform(self) -> WeightTransform: return reverse_transform - def prefix_was_changed(self): - return self._prefix_was_changed - def with_submodel_prefix(self, prefix: str) -> PrefixChange: new_prefix = f"{prefix}.{self.model_prefix}" if self.model_prefix != "" else prefix return PrefixChange( @@ -1427,16 +1417,9 @@ def convert_and_load_state_dict_in_model( # `cancel_futures=True` in case the program was interrupted, to avoid wasting time on exit thread_pool.shutdown(wait=False, cancel_futures=True) - # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user) - model_specific_conversions = [] - for conversion in weight_mapping: - # For a prefix change, we need to update at runtime depending on whether the checkpoint already had the correct - # format or not - otherwise, we may end up adding twice the same prefix - if isinstance(conversion, PrefixChange): - if conversion.prefix_was_changed(): - model_specific_conversions.append(conversion) - else: - model_specific_conversions.append(conversion) + # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user), but + # only if it was used, i.e. 
it matched any weight from the checkpoints + model_specific_conversions = [conversion for conversion in weight_mapping if conversion.was_used()] # Important: we need to revert the order here, so that potential conversions from submodels are performed first model._weight_conversions = model_specific_conversions[::-1] From e38bad1b466d8a39eae0339f7f37c94ad1453c92 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 23:03:24 +0900 Subject: [PATCH 18/36] start adding tests --- tests/utils/test_core_model_loading.py | 50 ++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index 2875f44088a7..6f0683a8495e 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -27,6 +27,7 @@ LinearToConv3d, MergeModulelist, PermuteForRope, + PrefixChange, WeightConverter, WeightRenaming, build_glob_alternation, @@ -778,6 +779,55 @@ def test_register_checkpoint_conversion_mapping_overwrites(self): self.assertEqual(len(get_checkpoint_conversion_mapping("foobarbaz")), 2) + def test_can_add_and_remove_full_parts(self): + model = DummyRoot() + model.config = PretrainedConfig() + + bad_serialized_checkpoints = {f"bad_name.{k}": v.clone() for k, v in model.state_dict().items()} + weight_mapping = [PrefixChange(remove_prefix="bad_name")] + + loading_info, _ = convert_and_load_state_dict_in_model( + model, + bad_serialized_checkpoints, + LoadStateDictConfig(weight_mapping=weight_mapping), + tp_plan=None, + ) + + # Assert we can load without issues + self.assertEqual(loading_info.missing_keys, set()) + self.assertEqual(loading_info.unexpected_keys, set()) + self.assertEqual(loading_info.mismatched_keys, set()) + self.assertEqual(loading_info.conversion_errors, {}) + + # Assert that re-saving will lead to the exact same state_dict, re-adding the bad prefix + saved_state_dict = revert_weight_conversion(model, model.state_dict()) + self.assertEqual(set(bad_serialized_checkpoints.keys()), set(saved_state_dict.keys())) + for k, v in saved_state_dict.items(): + self.assertTrue((v == bad_serialized_checkpoints[k]).all()) + + # Now, check that using the same conversion with already good keys works when loading and resaving + good_serialized_checkpoints = {k: v.clone() for k, v in model.state_dict().items()} + + loading_info, _ = convert_and_load_state_dict_in_model( + model, + good_serialized_checkpoints, + LoadStateDictConfig(weight_mapping=weight_mapping), + tp_plan=None, + ) + + # Assert we can load without issues + self.assertEqual(loading_info.missing_keys, set()) + self.assertEqual(loading_info.unexpected_keys, set()) + self.assertEqual(loading_info.mismatched_keys, set()) + self.assertEqual(loading_info.conversion_errors, {}) + + # Assert that re-saving will lead to the exact same state_dict, i.e. 
it will not re-add the bad prefix since it was + # not present at loading time + saved_state_dict = revert_weight_conversion(model, model.state_dict()) + self.assertEqual(set(good_serialized_checkpoints.keys()), set(saved_state_dict.keys())) + for k, v in saved_state_dict.items(): + self.assertTrue((v == good_serialized_checkpoints[k]).all()) + if __name__ == "__main__": unittest.main() From 2014ee152c923add13b9dcbc7be9d435981d4323 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 23:04:59 +0900 Subject: [PATCH 19/36] typo --- tests/utils/test_core_model_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index 6f0683a8495e..1d484cf229e9 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -784,7 +784,7 @@ def test_can_add_and_remove_full_parts(self): model.config = PretrainedConfig() bad_serialized_checkpoints = {f"bad_name.{k}": v.clone() for k, v in model.state_dict().items()} - weight_mapping = [PrefixChange(remove_prefix="bad_name")] + weight_mapping = [PrefixChange(prefix_to_remove="bad_name")] loading_info, _ = convert_and_load_state_dict_in_model( model, From f8acd0ab72a5720aef4fc5006328fbc66997157e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 15 Apr 2026 23:14:17 +0900 Subject: [PATCH 20/36] fix dot --- src/transformers/core_model_loading.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 74fa4a40b5b0..b4e23fcdd869 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -810,17 +810,12 @@ def __init__( self.prefix_to_add = prefix_to_add self.prefix_to_remove = prefix_to_remove self.model_prefix = "" if model_prefix is None else model_prefix + prefix = rf"{self.model_prefix}\." if self.model_prefix != "" else "" if prefix_to_add is not None: - super().__init__( - source_patterns=rf"^{self.model_prefix}\.(.+)$", - target_patterns=rf"{self.model_prefix}\.{prefix_to_add}\.\1", - ) + super().__init__(source_patterns=rf"^{prefix}(.+)$", target_patterns=rf"{prefix}{prefix_to_add}\.\1") else: - super().__init__( - source_patterns=rf"^{self.model_prefix}\.{prefix_to_remove}\.(.+)$", - target_patterns=rf"{self.model_prefix}\.\1", - ) + super().__init__(source_patterns=rf"^{prefix}{prefix_to_remove}\.(.+)$", target_patterns=rf"{prefix}\1") def reverse_transform(self) -> WeightTransform: """Reverse the current `WeightTransform` instance, to be able to save with the opposite weight transformations.""" From b0f5c268dea92b0e702faaaf971a859ac9b6bd30 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 16 Apr 2026 09:09:24 +0900 Subject: [PATCH 21/36] fix --- tests/utils/test_core_model_loading.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index 1d484cf229e9..e22039a4c169 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import unittest from types import SimpleNamespace @@ -789,7 +790,7 @@ def test_can_add_and_remove_full_parts(self): loading_info, _ = convert_and_load_state_dict_in_model( model, bad_serialized_checkpoints, - LoadStateDictConfig(weight_mapping=weight_mapping), + LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)), tp_plan=None, ) @@ -811,7 +812,7 @@ def test_can_add_and_remove_full_parts(self): loading_info, _ = convert_and_load_state_dict_in_model( model, good_serialized_checkpoints, - LoadStateDictConfig(weight_mapping=weight_mapping), + LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)), tp_plan=None, ) From e3bc9e82bcb91abe526c76fd8c63ae304e464c0d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 16 Apr 2026 09:55:09 +0900 Subject: [PATCH 22/36] more tests --- src/transformers/core_model_loading.py | 6 ++- tests/utils/test_core_model_loading.py | 58 ++++++++++++++++++++++++-- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index b4e23fcdd869..5c568781ba3e 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -813,7 +813,11 @@ def __init__( prefix = rf"{self.model_prefix}\." if self.model_prefix != "" else "" if prefix_to_add is not None: - super().__init__(source_patterns=rf"^{prefix}(.+)$", target_patterns=rf"{prefix}{prefix_to_add}\.\1") + # We use a lookbehind to avoid adding the prefix if we detect that it's already present + super().__init__( + source_patterns=rf"^{prefix}(?:(?!{prefix_to_add}\.))(.+)$", + target_patterns=rf"{prefix}{prefix_to_add}\.\1", + ) else: super().__init__(source_patterns=rf"^{prefix}{prefix_to_remove}\.(.+)$", target_patterns=rf"{prefix}\1") diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index e22039a4c169..b9112876c8d6 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -212,10 +212,11 @@ class DummyRoot(nn.Module): base_model_prefix = "model" config: PretrainedConfig - def __init__(self, add_extra_moe=False): + def __init__(self, add_extra_moe=False, with_mlp=True): super().__init__() self.model = DummyTopModel(add_extra_moe) - self.mlp = DummyMLP() + if with_mlp: + self.mlp = DummyMLP() class TestConvertAndLoadStateDict(unittest.TestCase): @@ -780,7 +781,7 @@ def test_register_checkpoint_conversion_mapping_overwrites(self): self.assertEqual(len(get_checkpoint_conversion_mapping("foobarbaz")), 2) - def test_can_add_and_remove_full_parts(self): + def test_can_remove_prefix(self): model = DummyRoot() model.config = PretrainedConfig() @@ -829,6 +830,57 @@ def test_can_add_and_remove_full_parts(self): for k, v in saved_state_dict.items(): self.assertTrue((v == good_serialized_checkpoints[k]).all()) + def test_can_add_prefix(self): + # we cannot have another param next to the model, otherwise the prefix adding will already be added even with correct + # checkpoints starting with the prefix + model = DummyRoot(with_mlp=False) + model.config = PretrainedConfig() + + bad_serialized_checkpoints = {k.removeprefix("model."): v.clone() for k, v in model.state_dict().items()} + weight_mapping = [PrefixChange(prefix_to_add="model")] + + loading_info, _ = convert_and_load_state_dict_in_model( + model, + bad_serialized_checkpoints, + LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)), + tp_plan=None, + ) + + # Assert we can load without issues + 
self.assertEqual(loading_info.missing_keys, set()) + self.assertEqual(loading_info.unexpected_keys, set()) + self.assertEqual(loading_info.mismatched_keys, set()) + self.assertEqual(loading_info.conversion_errors, {}) + + # Assert that re-saving will lead to the exact same state_dict, re-adding the bad prefix + saved_state_dict = revert_weight_conversion(model, model.state_dict()) + self.assertEqual(set(bad_serialized_checkpoints.keys()), set(saved_state_dict.keys())) + for k, v in saved_state_dict.items(): + self.assertTrue((v == bad_serialized_checkpoints[k]).all()) + + # Now, check that using the same conversion with already good keys works when loading and resaving + good_serialized_checkpoints = {k: v.clone() for k, v in model.state_dict().items()} + + loading_info, _ = convert_and_load_state_dict_in_model( + model, + good_serialized_checkpoints, + LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)), + tp_plan=None, + ) + + # Assert we can load without issues + self.assertEqual(loading_info.missing_keys, set()) + self.assertEqual(loading_info.unexpected_keys, set()) + self.assertEqual(loading_info.mismatched_keys, set()) + self.assertEqual(loading_info.conversion_errors, {}) + + # Assert that re-saving will lead to the exact same state_dict, i.e. it will not remove the prefix since it was + # already present at loading time + saved_state_dict = revert_weight_conversion(model, model.state_dict()) + self.assertEqual(set(good_serialized_checkpoints.keys()), set(saved_state_dict.keys())) + for k, v in saved_state_dict.items(): + self.assertTrue((v == good_serialized_checkpoints[k]).all()) + if __name__ == "__main__": unittest.main() From 2cb16331693dcc18e9cde12805d5c27eae32f8a1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 16 Apr 2026 10:01:37 +0900 Subject: [PATCH 23/36] add harder tests --- src/transformers/core_model_loading.py | 2 +- tests/utils/test_core_model_loading.py | 103 +++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 1 deletion(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 5c568781ba3e..47e1880434c0 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -813,8 +813,8 @@ def __init__( prefix = rf"{self.model_prefix}\." if self.model_prefix != "" else "" if prefix_to_add is not None: - # We use a lookbehind to avoid adding the prefix if we detect that it's already present super().__init__( + # We use a lookbehind to avoid adding the prefix if we detect that it's already present source_patterns=rf"^{prefix}(?:(?!{prefix_to_add}\.))(.+)$", target_patterns=rf"{prefix}{prefix_to_add}\.\1", ) diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index b9112876c8d6..b2da9efc06a2 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -881,6 +881,109 @@ def test_can_add_prefix(self): for k, v in saved_state_dict.items(): self.assertTrue((v == good_serialized_checkpoints[k]).all()) + def test_can_remove_prefix_submodule(self): + model = DummyRoot() + model.config = PretrainedConfig() + + bad_serialized_checkpoints = { + f"model.layers.bad_name.{k.replace('model.layers.', '')}" if "model.layers." 
in k else k: v.clone()
+            for k, v in model.state_dict().items()
+        }
+        weight_mapping = [PrefixChange(prefix_to_remove="bad_name", model_prefix="model.layers")]
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            bad_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, re-adding the bad prefix
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(bad_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == bad_serialized_checkpoints[k]).all())
+
+        # Now, check that using the same conversion with already good keys works when loading and resaving
+        good_serialized_checkpoints = {k: v.clone() for k, v in model.state_dict().items()}
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            good_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, i.e. it will not re-add the bad prefix since it was
+        # not present at loading time
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(good_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == good_serialized_checkpoints[k]).all())
+
+    def test_can_add_prefix_submodule(self):
+        # we cannot have another param next to the model, otherwise the prefix would still be added to it even with correct
+        # checkpoints already starting with the prefix
+        model = DummyRoot(with_mlp=False)
+        model.config = PretrainedConfig()
+
+        bad_serialized_checkpoints = {k.replace(".layers.", "."): v.clone() for k, v in model.state_dict().items()}
+        weight_mapping = [PrefixChange(prefix_to_add="layers", model_prefix="model")]
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            bad_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, re-adding the bad prefix
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(bad_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == bad_serialized_checkpoints[k]).all())
+
+        # Now, check that using the same conversion with already good keys works when loading and resaving
+        good_serialized_checkpoints = {k: v.clone() for k, v in model.state_dict().items()}
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            good_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, i.e. it will not remove the prefix since it was
+        # already present at loading time
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(good_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == good_serialized_checkpoints[k]).all())
+

 if __name__ == "__main__":
     unittest.main()

From d02310f4f5465f5d0fafd83891c2a5c3a2aaaa95 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 16 Apr 2026 10:25:43 +0900
Subject: [PATCH 24/36] fix

---
 src/transformers/core_model_loading.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index 47e1880434c0..2259b676f77f 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -1438,6 +1438,10 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch

     # Do not resave with the legacy renaming, if present
     weight_conversions = get_model_conversion_mapping(model, add_legacy=False)
+    # If the model had no `_weight_conversions` attached, drop any PrefixChange transform - this is because the
+    # model was almost surely instantiated from scratch, and PrefixChange with `prefix_to_remove` would otherwise
+    # add an unwanted prefix (as we don't have any information about whether the prefix was there or not during load)
+    weight_conversions = [x for x in weight_conversions if not isinstance(x, PrefixChange)]
     weight_conversions = weight_conversions if len(weight_conversions) > 0 else None

     # We did not find any operations to perform -> quick escape

From fa3bc4726f95d338f272de5ce889c437523c2aff Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 16 Apr 2026 10:31:08 +0900
Subject: [PATCH 25/36] improve tests

---
 tests/utils/test_core_model_loading.py | 36 ++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py
index b2da9efc06a2..a358822d19f8 100644
--- a/tests/utils/test_core_model_loading.py
+++ b/tests/utils/test_core_model_loading.py
@@ -830,6 +830,15 @@ def test_can_remove_prefix(self):
         for k, v in saved_state_dict.items():
             self.assertTrue((v == good_serialized_checkpoints[k]).all())

+        # Now, use a fresh model, without going through loading first, so the model won't have `_weight_conversions` attached
+        # and the prefix should not be added when saving directly (i.e. 
the conversion should be dropped)
+        model = DummyRoot()
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        model_state_dict = model.state_dict()
+        self.assertEqual(set(model_state_dict.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == model_state_dict[k]).all())
+
     def test_can_add_prefix(self):
         # we cannot have another param next to the model, otherwise the prefix would still be added to it even with correct
         # checkpoints already starting with the prefix
@@ -890,6 +899,15 @@ def test_can_add_prefix(self):
         for k, v in saved_state_dict.items():
             self.assertTrue((v == good_serialized_checkpoints[k]).all())

+        # Now, use a fresh model, without going through loading first, so the model won't have `_weight_conversions` attached
+        # and the prefix should not be removed when saving directly (i.e. the conversion should be dropped)
+        model = DummyRoot()
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        model_state_dict = model.state_dict()
+        self.assertEqual(set(model_state_dict.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == model_state_dict[k]).all())
+
     def test_can_remove_prefix_submodule(self):
         model = DummyRoot()
         model.config = PretrainedConfig()
@@ -951,6 +969,15 @@ def test_can_remove_prefix_submodule(self):
         for k, v in saved_state_dict.items():
             self.assertTrue((v == good_serialized_checkpoints[k]).all())

+        # Now, use a fresh model, without going through loading first, so the model won't have `_weight_conversions` attached
+        # and the prefix should not be added when saving directly (i.e. the conversion should be dropped)
+        model = DummyRoot()
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        model_state_dict = model.state_dict()
+        self.assertEqual(set(model_state_dict.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == model_state_dict[k]).all())
+
     def test_can_add_prefix_submodule(self):
         # we cannot have another param next to the model, otherwise the prefix would still be added to it even with correct
         # checkpoints already starting with the prefix
@@ -1011,6 +1029,15 @@ def test_can_add_prefix_submodule(self):
         for k, v in saved_state_dict.items():
             self.assertTrue((v == good_serialized_checkpoints[k]).all())

+        # Now, use a fresh model, without going through loading first, so the model won't have `_weight_conversions` attached
+        # and the prefix should not be removed when saving directly (i.e. 
the conversion should be dropped)
+        model = DummyRoot()
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        model_state_dict = model.state_dict()
+        self.assertEqual(set(model_state_dict.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == model_state_dict[k]).all())
+

 if __name__ == "__main__":
     unittest.main()

From 3ba24babd21c588132321b21f13c6f678b67ff54 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 16 Apr 2026 10:32:55 +0900
Subject: [PATCH 26/36] comment

---
 src/transformers/conversion_mapping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 4d6ad4e481a9..959db738ae5a 100755
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -633,7 +633,7 @@ def extract_weight_conversions_for_model(model: PreTrainedModel, model_prefix: s
     # In this case, add the prefix to `PrefixChange` instances, in order to know where to add/remove the prefix
     if model_specific_conversions is not None and model_prefix != "":
         for i, conversion in enumerate(model_specific_conversions):
-            # In this case, add the prefix
+            # In this case, add the prefix, as otherwise we don't know where we need to re-add it exactly in the module name chain
             if isinstance(conversion, PrefixChange):
                 model_specific_conversions[i] = conversion.with_submodel_prefix(model_prefix)
     return model_specific_conversions

From 33917c11e90cc8d93eeed5be7b2b3b63220cf722 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 16 Apr 2026 11:01:47 +0900
Subject: [PATCH 27/36] doc

---
 src/transformers/core_model_loading.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index 2259b676f77f..e130f4ff5800 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -747,7 +747,13 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
         return collected_tensors

     def was_used(self) -> bool:
-        """Return whether the current Transform matched any weights during loading/saving"""
+        """
+        Return whether the current Transform matched any weights during loading/saving. This is needed as some
+        weight renaming transforms are not bijective, i.e. if we drop/add full parts of a name with PrefixChange, we
+        lose some information that we cannot get back unless we know whether the Transform was actually used (say we
+        have a prefix to drop: we need to know whether the loaded checkpoints contained that prefix in order to decide
+        whether to add it back during saving).
+        """
         return self._was_used


@@ -792,8 +798,15 @@ def convert(


 class PrefixChange(WeightRenaming):
-    # Special case of weight renaming, used to easily add/remove a prefix while removing/adding it back
-    # easily as well during saving
+    """
+    Special case of WeightRenaming, used to simplify adding/removing full parts of a weight name. The regexes
+    that are needed for such operations are complex, so this is a much easier API for such cases.
+
+    It also correctly handles the revert operations, which are in general not bijective for addition/removal of full
+    name parts. Indeed, if we drop/add full parts of a name, we lose some information that we cannot get back unless we
+    know whether the Transform was used before. 
For example, with a prefix to drop, we need to know whether the checkpoints
+    we actually loaded contained that prefix in order to decide whether to add it back during saving.
+    """

     __slots__ = (
         "prefix_to_add",

From 6ee1496d2bd0ae5cd211cdc7e4bc631874f0f807 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 16 Apr 2026 11:10:47 +0900
Subject: [PATCH 28/36] skip in tests

---
 tests/test_modeling_common.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 9dbf44c03c12..a1ee2babe012 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -44,7 +44,7 @@
     set_seed,
 )
 from transformers.conversion_mapping import get_model_conversion_mapping
-from transformers.core_model_loading import WeightRenaming, process_target_pattern
+from transformers.core_model_loading import PrefixChange, WeightRenaming, process_target_pattern
 from transformers.integrations import HfDeepSpeedConfig
 from transformers.integrations.deepspeed import (
     is_deepspeed_available,
@@ -4783,6 +4783,11 @@ def test_reverse_loading_mapping(self, check_keys_were_modified=True, skip_base_
             conversions = get_model_conversion_mapping(model, add_legacy=False)
             if len(conversions) == 0:
                 self.skipTest(f"No conversion found for {model_class}")
+            # The PrefixChange conversions are only there for BC with hub checkpoints, but cannot be tested
+            # for, as we skip them automatically if they are not present in loaded checkpoints (we want to
+            # mess up the prefixes only if the loaded checkpoints were doing so as well)
+            if all(isinstance(conversion, PrefixChange) for conversion in conversions):
+                self.skipTest(f"Only PrefixChange conversions found for {model_class}")

             # Find the model keys, so the targets according to the conversions
             model_keys = list(model.state_dict().keys())
@@ -4800,6 +4805,11 @@ def test_reverse_loading_mapping(self, check_keys_were_modified=True, skip_base_

             # Check that for each conversion entry, we at least map to one key
             for conversion in conversions:
+                # The PrefixChange conversions are only there for BC with hub checkpoints, but cannot be tested
+                # for, as we skip them automatically if they are not present in loaded checkpoints (we want to
+                # mess up the prefixes only if the loaded checkpoints were doing so as well)
+                if isinstance(conversion, PrefixChange):
+                    continue
                 for source_pattern in conversion.source_patterns:
                     # Some patterns are written for gen-model only and won't be applied on base model
                     if "lm_head" in source_pattern and model_class not in [

From a534dfb56670a283f69413db2d3d50e603838c82 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 16 Apr 2026 15:43:57 +0900
Subject: [PATCH 29/36] fix cohere_asr mapping

---
 src/transformers/conversion_mapping.py | 22 +++++++++++-----------
 src/transformers/core_model_loading.py |  6 ++++--
 tests/test_modeling_common.py          |  7 ++++++-
 3 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 959db738ae5a..c03e539e93e3 100755
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -571,17 +571,17 @@ def _build_checkpoint_conversion_mapping():
         WeightRenaming(r"transf_decoder\._decoder\.final_layer_norm", r"decoder.norm"),
         WeightRenaming(r"transf_decoder\._decoder\.layers", r"decoder.layers"),
         WeightRenaming(r"encoder_decoder_proj\.", r"decoder.proj."),
-        WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_q", 
r"encoder.(.+).self_attn.q_proj"), - WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_k", r"encoder.(.+).self_attn.k_proj"), - WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_v", r"encoder.(.+).self_attn.v_proj"), - WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_out", r"encoder.(.+).self_attn.o_proj"), - WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_pos", r"encoder.(.+).self_attn.relative_k_proj"), - WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_u", r"encoder.(.+).self_attn.bias_u"), - WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_v", r"encoder.(.+).self_attn.bias_v"), - WeightRenaming(r"\.first_sub_layer\.query_net", r".self_attn.q_proj"), - WeightRenaming(r"\.first_sub_layer\.key_net", r".self_attn.k_proj"), - WeightRenaming(r"\.first_sub_layer\.value_net", r".self_attn.v_proj"), - WeightRenaming(r"\.first_sub_layer\.out_projection", r".self_attn.o_proj"), + WeightRenaming(r"encoder\.(.+)\.linear_q", r"encoder.\1.q_proj"), + WeightRenaming(r"encoder\.(.+)\.linear_k", r"encoder.\1.k_proj"), + WeightRenaming(r"encoder\.(.+)\.linear_v", r"encoder.\1.v_proj"), + WeightRenaming(r"encoder\.(.+)\.linear_out", r"encoder.\1.o_proj"), + WeightRenaming(r"encoder\.(.+)\.linear_pos", r"encoder.\1.relative_k_proj"), + WeightRenaming(r"encoder\.(.+)\.pos_bias_u", r"encoder.\1.bias_u"), + WeightRenaming(r"encoder\.(.+)\.pos_bias_v", r"encoder.\1.bias_v"), + WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.query_net", r"decoder.\1.self_attn.q_proj"), + WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.key_net", r"decoder.\1.self_attn.k_proj"), + WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.value_net", r"decoder.\1.self_attn.v_proj"), + WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.out_projection", r"decoder.\1.self_attn.o_proj"), WeightRenaming(r"\.second_sub_layer\.query_net", r".encoder_attn.q_proj"), WeightRenaming(r"\.second_sub_layer\.key_net", r".encoder_attn.k_proj"), WeightRenaming(r"\.second_sub_layer\.value_net", r".encoder_attn.v_proj"), diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index e130f4ff5800..cf7497139e2b 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -1432,8 +1432,7 @@ def convert_and_load_state_dict_in_model( # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user), but # only if it was used, i.e. 
it matched any weight from the checkpoints model_specific_conversions = [conversion for conversion in weight_mapping if conversion.was_used()] - # Important: we need to revert the order here, so that potential conversions from submodels are performed first - model._weight_conversions = model_specific_conversions[::-1] + model._weight_conversions = model_specific_conversions return loading_info, disk_offload_index @@ -1461,6 +1460,9 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch if weight_conversions is None: return state_dict + # Important: we need to revert the order here, so that potential conversions from submodels are performed first + weight_conversions = weight_conversions[::-1] + # Reverse all Transform to correctly match keys reverse_weight_conversion = [conversion.reverse_transform() for conversion in weight_conversions] # If we are still here, we need to create the (reverse) conversion mapping from scratch diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a1ee2babe012..b77fea367bb8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -5682,7 +5682,12 @@ def compare_state_dicts(state_dict1, state_dict2) -> bool: """Make sure 2 state dicts are the exact same""" # Make sure the keys are the exact same if sorted(state_dict1.keys()) != sorted(state_dict2.keys()): - raise ValueError("The keys of both state dict are not the same") + in1_not2 = sorted(set(state_dict1.keys()) - set(state_dict2.keys())) + in2_not1 = sorted(set(state_dict2.keys()) - set(state_dict1.keys())) + raise ValueError( + f"The keys of both state dict are not the same.\nKeys found in the first item but not second: {in1_not2}" + f"\nKeys found in the second item but not first: {in2_not1}" + ) for k, v1 in state_dict1.items(): v2 = state_dict2[k] From b0bc3420b195cdad080442931c26edf5810488f5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 16 Apr 2026 17:24:24 +0900 Subject: [PATCH 30/36] add other needed models to mapping --- src/transformers/conversion_mapping.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index c03e539e93e3..8c7e5761a6a7 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -77,6 +77,11 @@ "pp_chart2table": "llava", "gemma3n_text": "qwen3_5_text", "qwen3_5_moe_text": "qwen3_5_text", + "altclip_vision_model": "clip_vision_model", + "chinese_clip_vision_model": "clip_vision_model", + "clipseg_vision_model": "clip_vision_model", + "mlcd_vision": "clip_vision_model", + "mlcd": "clip_vision_model", } From b80d9ff7da8dd65ae1b6abbc223acd1bb5ffd5ce Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 16 Apr 2026 17:45:25 +0900 Subject: [PATCH 31/36] add text mappings --- src/transformers/conversion_mapping.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 8c7e5761a6a7..ba2fe8a5bb25 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -80,8 +80,17 @@ "altclip_vision_model": "clip_vision_model", "chinese_clip_vision_model": "clip_vision_model", "clipseg_vision_model": "clip_vision_model", + "metaclip_2_vision_model": "clip_vision_model", "mlcd_vision": "clip_vision_model", "mlcd": "clip_vision_model", + "siglip_vision_model": "clip_vision_model", + "siglip2_vision_model": "clip_vision_model", + "xclip_vision_model": 
"clip_vision_model", + "clipseg_text_model": "clip_text_model", + "metaclip_2_text_model": "clip_text_model", + "siglip_text_model": "clip_text_model", + "siglip2_text_model": "clip_text_model", + "xclip_text_model": "clip_text_model", } @@ -104,6 +113,7 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^image_newline", target_patterns="model.image_newline"), ], "clip_vision_model": [PrefixChange(prefix_to_remove="vision_model")], + "clip_text_model": [PrefixChange(prefix_to_remove="text_model")], "video_llava": [ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), From f8d2dffce657fe996363e7e50a2622538ffe6c58 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 17 Apr 2026 10:46:56 +0900 Subject: [PATCH 32/36] add back --- src/transformers/conversion_mapping.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index ba2fe8a5bb25..88155e4b4daa 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -586,13 +586,13 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(r"transf_decoder\._decoder\.final_layer_norm", r"decoder.norm"), WeightRenaming(r"transf_decoder\._decoder\.layers", r"decoder.layers"), WeightRenaming(r"encoder_decoder_proj\.", r"decoder.proj."), - WeightRenaming(r"encoder\.(.+)\.linear_q", r"encoder.\1.q_proj"), - WeightRenaming(r"encoder\.(.+)\.linear_k", r"encoder.\1.k_proj"), - WeightRenaming(r"encoder\.(.+)\.linear_v", r"encoder.\1.v_proj"), - WeightRenaming(r"encoder\.(.+)\.linear_out", r"encoder.\1.o_proj"), - WeightRenaming(r"encoder\.(.+)\.linear_pos", r"encoder.\1.relative_k_proj"), - WeightRenaming(r"encoder\.(.+)\.pos_bias_u", r"encoder.\1.bias_u"), - WeightRenaming(r"encoder\.(.+)\.pos_bias_v", r"encoder.\1.bias_v"), + WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.\1.self_attn.q_proj"), + WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_k", r"encoder.\1.self_attn.k_proj"), + WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_v", r"encoder.\1.self_attn.v_proj"), + WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_out", r"encoder.\1.self_attn.o_proj"), + WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_pos", r"encoder.\1.self_attn.relative_k_proj"), + WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_u", r"encoder.\1.self_attn.bias_u"), + WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_v", r"encoder.\1.self_attn.bias_v"), WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.query_net", r"decoder.\1.self_attn.q_proj"), WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.key_net", r"decoder.\1.self_attn.k_proj"), WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.value_net", r"decoder.\1.self_attn.v_proj"), From c2a1b6d3e59bdcb264422a541f9c64939f5cdcd5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 17 Apr 2026 10:56:59 +0900 Subject: [PATCH 33/36] better comment --- src/transformers/core_model_loading.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index cf7497139e2b..ddd1ecd3c08b 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -1451,8 +1451,9 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch # Do not resave with the legacy 
renaming, if present
     weight_conversions = get_model_conversion_mapping(model, add_legacy=False)
     # If the model had no `_weight_conversions` attached, drop any PrefixChange transform - this is because the
-    # model was almost surely instantiated from scratch, and PrefixChange with `prefix_to_remove` would otherwise
-    # add an unwanted prefix (as we don't have any information about whether the prefix was there or not during load)
+    # model was almost surely instantiated from scratch (at least not from `from_pretrained`), and PrefixChange with
+    # `prefix_to_remove` would otherwise add an unwanted prefix (as we don't have any information about whether the prefix
+    # was there or not during load)
     weight_conversions = [x for x in weight_conversions if not isinstance(x, PrefixChange)]
     weight_conversions = weight_conversions if len(weight_conversions) > 0 else None


From fd1d8e74ba8f9a610a3fafd4e47c801570893d9b Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Fri, 17 Apr 2026 11:20:43 +0900
Subject: [PATCH 34/36] simplify

---
 src/transformers/core_model_loading.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index ddd1ecd3c08b..7c107a1f06e4 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -840,12 +840,10 @@ def reverse_transform(self) -> WeightTransform:
         if self.quantization_operation is not None:
             raise ValueError("Cannot reverse the transform with TP or quantization")

-        if self.prefix_to_add is not None:
-            reverse_transform = PrefixChange(prefix_to_remove=self.prefix_to_add, model_prefix=self.model_prefix)
-        else:
-            reverse_transform = PrefixChange(prefix_to_add=self.prefix_to_remove, model_prefix=self.model_prefix)
-
-        return reverse_transform
+        # Only one of the two can ever be set, so the other is always None
+        return PrefixChange(
+            prefix_to_add=self.prefix_to_remove, prefix_to_remove=self.prefix_to_add, model_prefix=self.model_prefix
+        )

     def with_submodel_prefix(self, prefix: str) -> PrefixChange:
         new_prefix = f"{prefix}.{self.model_prefix}" if self.model_prefix != "" else prefix

From a19ad84d58e0f7cec20c5218659b13a8af75334e Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Fri, 17 Apr 2026 17:06:25 +0900
Subject: [PATCH 35/36] remove overridden test

---
 tests/models/altclip/test_modeling_altclip.py | 55 -------------------
 1 file changed, 55 deletions(-)

diff --git a/tests/models/altclip/test_modeling_altclip.py b/tests/models/altclip/test_modeling_altclip.py
index 77aeddc31b11..bf849c031e3a 100755
--- a/tests/models/altclip/test_modeling_altclip.py
+++ b/tests/models/altclip/test_modeling_altclip.py
@@ -13,11 +13,7 @@
 # limitations under the License.
"""Testing suite for the PyTorch AltCLIP model.""" -import copy import inspect -import os -import re -import tempfile import unittest import numpy as np @@ -25,8 +21,6 @@ from parameterized import parameterized from transformers import AltCLIPConfig, AltCLIPProcessor, AltCLIPTextConfig, AltCLIPVisionConfig -from transformers.conversion_mapping import get_model_conversion_mapping -from transformers.core_model_loading import WeightRenaming, process_target_pattern from transformers.testing_utils import is_flaky, require_torch, require_vision, slow, torch_device from transformers.utils import is_torch_available, is_vision_available @@ -44,7 +38,6 @@ if is_torch_available(): import torch import torch.nn as nn - from safetensors.torch import load_file from transformers import AltCLIPModel, AltCLIPTextModel, AltCLIPVisionModel @@ -472,54 +465,6 @@ def test_model_from_pretrained(self): model = AltCLIPModel.from_pretrained(model_name) self.assertIsNotNone(model) - def test_reverse_loading_mapping(self, check_keys_were_modified=True): - # AltCLIP applies legacy conversion which is never reversed, so we won't get - # matching state dict keys after re-saving it back - - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - # Each individual model is a subtest - with self.subTest(model_class.__name__): - model = model_class(copy.deepcopy(config)) - # Skip if no conversions - conversions = get_model_conversion_mapping(model, add_legacy=False) - if len(conversions) == 0: - self.skipTest(f"No conversion found for {model_class}") - - # Find the model keys, so the targets according to the conversions - model_keys = list(model.state_dict().keys()) - - with tempfile.TemporaryDirectory() as tmpdirname: - # Serialize with reverse mapping - model.save_pretrained(tmpdirname) - state_dict = load_file(os.path.join(tmpdirname, "model.safetensors")) - # Get all the serialized keys that we just saved according to the reverse mapping - serialized_keys = list(state_dict.keys()) - - if check_keys_were_modified: - # They should be different, otherwise we did not perform any mapping - self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!") - - # Check that for each conversion entry, we at least map to one key - for conversion in conversions: - for source_pattern in conversion.source_patterns: - # Sometimes the mappings specify keys that are tied, so absent from the saved state dict - if isinstance(conversion, WeightRenaming): - # We need to revert the target pattern to make it compatible with regex search - target_pattern_reversed = conversion.target_patterns[0] - captured_group = process_target_pattern(source_pattern)[1] - if captured_group: - target_pattern_reversed = target_pattern_reversed.replace(r"\1", captured_group) - if any(re.search(target_pattern_reversed, k) for k in model.all_tied_weights_keys.keys()): - continue - num_matches = sum(re.search(source_pattern, key) is not None for key in serialized_keys) - self.assertTrue( - num_matches > 0, - f"`{source_pattern}` in `{conversion}` did not match any of the source keys. 
" - "This indicates whether that the pattern is not properly written, or that it could not be reversed correctly", - ) - @require_vision @require_torch From 94eaeb043a5b1c85fd3e30471edccf86483a77db Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Apr 2026 12:19:39 +0900 Subject: [PATCH 36/36] deduplicate doc --- src/transformers/core_model_loading.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 7c107a1f06e4..cd0710649c91 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -801,11 +801,6 @@ class PrefixChange(WeightRenaming): """ Special case of WeightRenaming, used to simplify adding/removing full parts of a weight name. The regexes that are needed for such operations are complex, so this is a much easier API for such cases. - - It also correctly handles the revert operations, which are in general not bijective for addition/removal of full - name parts. Indeed, if we drop/add full parts of a name, we lose some informations that we cannot get back if we don't - know if the Transform was used before. For example, say we have a prefix to drop, we need to know whether the checkpoints - we actually loaded before contained the said prefix or not before adding it back, or not, during saving. """ __slots__ = (