Merged

37 commits
6069e63
try
Cyrilvallez Apr 15, 2026
b9c885f
fix
Cyrilvallez Apr 15, 2026
9003fd8
oupsi typo
Cyrilvallez Apr 15, 2026
2572a25
oupsi typo
Cyrilvallez Apr 15, 2026
d0f7fb2
get rid of dataclasses
Cyrilvallez Apr 15, 2026
11bc494
try
Cyrilvallez Apr 15, 2026
8de7e0f
oupsi
Cyrilvallez Apr 15, 2026
7f38c23
revert from before
Cyrilvallez Apr 15, 2026
b98194d
fix
Cyrilvallez Apr 15, 2026
4c05d6e
add parenthesis
Cyrilvallez Apr 15, 2026
b2f8cc8
fix
Cyrilvallez Apr 15, 2026
80d8386
fix
Cyrilvallez Apr 15, 2026
d3cc313
fixes
Cyrilvallez Apr 15, 2026
e925ce8
need to revert the order for saving
Cyrilvallez Apr 15, 2026
e792532
comment
Cyrilvallez Apr 15, 2026
c48218d
a bit more general
Cyrilvallez Apr 15, 2026
01cda19
simplify
Cyrilvallez Apr 15, 2026
e38bad1
start adding tests
Cyrilvallez Apr 15, 2026
2014ee1
typo
Cyrilvallez Apr 15, 2026
f8acd0a
fix dot
Cyrilvallez Apr 15, 2026
b0f5c26
fix
Cyrilvallez Apr 16, 2026
e3bc9e8
more tests
Cyrilvallez Apr 16, 2026
2cb1633
add harder tests
Cyrilvallez Apr 16, 2026
d02310f
fix
Cyrilvallez Apr 16, 2026
fa3bc47
improve tests
Cyrilvallez Apr 16, 2026
3ba24ba
comment
Cyrilvallez Apr 16, 2026
33917c1
doc
Cyrilvallez Apr 16, 2026
6ee1496
skip in tests
Cyrilvallez Apr 16, 2026
a534dfb
fix cohere_asr mapping
Cyrilvallez Apr 16, 2026
b0bc342
add other needed models to mapping
Cyrilvallez Apr 16, 2026
b80d9ff
add text mappings
Cyrilvallez Apr 16, 2026
f8d2dff
add back
Cyrilvallez Apr 17, 2026
c2a1b6d
better comment
Cyrilvallez Apr 17, 2026
fd1d8e7
simplify
Cyrilvallez Apr 17, 2026
a19ad84
remove overriden test
Cyrilvallez Apr 17, 2026
f168035
Merge branch 'main' into fix-clips
Cyrilvallez Apr 17, 2026
94eaeb0
deduplicate doc
Cyrilvallez Apr 20, 2026
69 changes: 42 additions & 27 deletions src/transformers/conversion_mapping.py
@@ -22,9 +22,11 @@
Concatenate,
ErnieFuseAndSplitTextVisionExperts,
MergeModulelist,
PrefixChange,
Transpose,
WeightConverter,
WeightRenaming,
WeightTransform,
)


@@ -75,6 +77,20 @@
"pp_chart2table": "llava",
"gemma3n_text": "qwen3_5_text",
"qwen3_5_moe_text": "qwen3_5_text",
"altclip_vision_model": "clip_vision_model",
"chinese_clip_vision_model": "clip_vision_model",
"clipseg_vision_model": "clip_vision_model",
"metaclip_2_vision_model": "clip_vision_model",
"mlcd_vision": "clip_vision_model",
"mlcd": "clip_vision_model",
"siglip_vision_model": "clip_vision_model",
"siglip2_vision_model": "clip_vision_model",
"xclip_vision_model": "clip_vision_model",
"clipseg_text_model": "clip_text_model",
"metaclip_2_text_model": "clip_text_model",
"siglip_text_model": "clip_text_model",
"siglip2_text_model": "clip_text_model",
"xclip_text_model": "clip_text_model",
}
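For readers unfamiliar with this table: the idea is that all of these CLIP-family model types resolve to a single canonical entry, so they can share one conversion list. A minimal sketch of that lookup, where the helper and the two dictionaries are hypothetical stand-ins rather than the real module-level objects:

```python
# Hypothetical sketch of alias resolution, not the actual transformers implementation.
_MODEL_TYPE_ALIASES = {
    "siglip_vision_model": "clip_vision_model",
    "xclip_text_model": "clip_text_model",
}
_CONVERSIONS = {
    "clip_vision_model": ["PrefixChange(prefix_to_remove='vision_model')"],
    "clip_text_model": ["PrefixChange(prefix_to_remove='text_model')"],
}

def resolve_conversions(model_type: str):
    # Fall back to the model type itself when no alias is registered.
    canonical = _MODEL_TYPE_ALIASES.get(model_type, model_type)
    return _CONVERSIONS.get(canonical)

assert resolve_conversions("siglip_vision_model") == resolve_conversions("clip_vision_model")
```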


@@ -96,6 +112,8 @@ def _build_checkpoint_conversion_mapping():
WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"),
WeightRenaming(source_patterns=r"^image_newline", target_patterns="model.image_newline"),
],
"clip_vision_model": [PrefixChange(prefix_to_remove="vision_model")],
"clip_text_model": [PrefixChange(prefix_to_remove="text_model")],
"video_llava": [
WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"),
WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
@@ -568,17 +586,17 @@ def _build_checkpoint_conversion_mapping():
WeightRenaming(r"transf_decoder\._decoder\.final_layer_norm", r"decoder.norm"),
WeightRenaming(r"transf_decoder\._decoder\.layers", r"decoder.layers"),
WeightRenaming(r"encoder_decoder_proj\.", r"decoder.proj."),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.(.+).self_attn.q_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_k", r"encoder.(.+).self_attn.k_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_v", r"encoder.(.+).self_attn.v_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_out", r"encoder.(.+).self_attn.o_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_pos", r"encoder.(.+).self_attn.relative_k_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_u", r"encoder.(.+).self_attn.bias_u"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_v", r"encoder.(.+).self_attn.bias_v"),
WeightRenaming(r"\.first_sub_layer\.query_net", r".self_attn.q_proj"),
WeightRenaming(r"\.first_sub_layer\.key_net", r".self_attn.k_proj"),
WeightRenaming(r"\.first_sub_layer\.value_net", r".self_attn.v_proj"),
WeightRenaming(r"\.first_sub_layer\.out_projection", r".self_attn.o_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.\1.self_attn.q_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_k", r"encoder.\1.self_attn.k_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_v", r"encoder.\1.self_attn.v_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_out", r"encoder.\1.self_attn.o_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_pos", r"encoder.\1.self_attn.relative_k_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_u", r"encoder.\1.self_attn.bias_u"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_v", r"encoder.\1.self_attn.bias_v"),
WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.query_net", r"decoder.\1.self_attn.q_proj"),
WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.key_net", r"decoder.\1.self_attn.k_proj"),
WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.value_net", r"decoder.\1.self_attn.v_proj"),
WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.out_projection", r"decoder.\1.self_attn.o_proj"),
WeightRenaming(r"\.second_sub_layer\.query_net", r".encoder_attn.q_proj"),
WeightRenaming(r"\.second_sub_layer\.key_net", r".encoder_attn.k_proj"),
WeightRenaming(r"\.second_sub_layer\.value_net", r".encoder_attn.v_proj"),
@@ -623,10 +641,16 @@ def register_checkpoint_conversion_mapping(
_checkpoint_conversion_mapping_cache[model_type] = mapping


def extract_weight_conversions_for_model(model: PreTrainedModel) -> list[WeightConverter | WeightRenaming] | None:
def extract_weight_conversions_for_model(model: PreTrainedModel, model_prefix: str) -> list[WeightTransform] | None:
model_type = getattr(model.config, "model_type", None)
if model_type is not None:
model_specific_conversions = get_checkpoint_conversion_mapping(model_type)
# In this case, add the prefix to `PrefixChange` instances, in order to know where to add/remove the prefix
if model_specific_conversions is not None and model_prefix != "":
for i, conversion in enumerate(model_specific_conversions):
# In this case, add the prefix, as otherwise we don't know where we need to re-add it exactly in the module name chain
if isinstance(conversion, PrefixChange):
model_specific_conversions[i] = conversion.with_submodel_prefix(model_prefix)
return model_specific_conversions
return None
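As a rough illustration of what folding the submodel prefix into a `PrefixChange` achieves (the names `old_prefix` and `sub_model` are invented, and the regex shapes are only sketched; the real ones are built by `PrefixChange` further down in this PR):

```python
import re

# A PrefixChange(prefix_to_remove="old_prefix") built for a standalone model:
standalone_src = r"^old_prefix\.(.+)$"
print(re.sub(standalone_src, r"\1", "old_prefix.linear.weight"))
# -> "linear.weight"

# The same transform with the submodel prefix "sub_model" folded in only strips
# "old_prefix." at that exact location in the module name chain:
nested_src = r"^sub_model\.old_prefix\.(.+)$"
print(re.sub(nested_src, r"sub_model.\1", "sub_model.old_prefix.linear.weight"))
# -> "sub_model.linear.weight"
```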

@@ -636,7 +660,7 @@ def get_model_conversion_mapping(
key_mapping: dict[str, str] | None = None,
hf_quantizer: HfQuantizer | None = None,
add_legacy: bool = True,
) -> list[WeightConverter | WeightRenaming]:
) -> list[WeightTransform]:
"""
For a given `model`, obtain the weight conversion mapping if any are registered either as a simple renaming
`_checkpoint_conversion_mapping` class argument, or in the general WeightConverter mapping.
@@ -651,22 +675,13 @@
if key_mapping is not None:
weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()]

# Model have several `PreTrainedModel` within with the same model type
# For ex: XForConditionalGeneration -> XModel. We don't want to apply the same
# conversion pattern twice because of that
# A model may contain several `PreTrainedModel` submodules with the same model type, for example XForConditionalGeneration -> XModel
# We don't want to apply the same conversion pattern twice because of that
seen_model_types = set()
if (conversions := extract_weight_conversions_for_model(model)) is not None:
weight_conversions.extend(conversions)
seen_model_types.add(model.config.model_type)

# Recurse over submodules and collect all conversions
for submodule in model.modules():
if (
submodule is not model
and isinstance(submodule, PreTrainedModel)
and submodule.config.model_type not in seen_model_types
):
conversions = extract_weight_conversions_for_model(submodule)
for name, submodule in model.named_modules():
if isinstance(submodule, PreTrainedModel) and submodule.config.model_type not in seen_model_types:
conversions = extract_weight_conversions_for_model(submodule, name)
if conversions is not None:
weight_conversions.extend(conversions)
seen_model_types.add(submodule.config.model_type)
167 changes: 129 additions & 38 deletions src/transformers/core_model_loading.py
@@ -25,7 +25,6 @@
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import chain
from typing import TYPE_CHECKING, Any

@@ -579,41 +578,40 @@ def process_source_pattern(source_pattern: str, target_pattern: str) -> str:
return source_pattern


@dataclass(slots=True)
Comment on lines 580 to 582 (@Cyrilvallez, Member, Author, Apr 16, 2026):

They were dataclasses, but it did not make any sense, so I removed it (but kept the slots, the only feature we were really using from dataclass; it makes it much easier to inherit, etc.).
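A rough sketch of the trade-off that comment describes, using simplified stand-in classes rather than the actual `WeightTransform`:

```python
from dataclasses import dataclass, field

# Before: a slotted dataclass, where the slots were the only feature really used.
@dataclass(slots=True)
class TransformAsDataclass:
    source_patterns: list
    target_patterns: list
    collected: dict = field(default_factory=dict)

# After: a plain class keeping __slots__ (restricted attributes, lower memory),
# with an ordinary __init__, which makes subclassing with extra slots and a
# custom __setattr__ more straightforward.
class TransformAsPlainClass:
    __slots__ = ("source_patterns", "target_patterns", "collected")

    def __init__(self, source_patterns, target_patterns):
        self.source_patterns = source_patterns
        self.target_patterns = target_patterns
        self.collected = {}

t = TransformAsPlainClass(["a"], ["b"])
# t.extra = 1  # would raise AttributeError thanks to __slots__
```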

class WeightTransform:
source_patterns: str | list[str] = field(init=True)
target_patterns: str | list[str] = field(init=True)
compiled_sources: re.Pattern = field(init=False)

distributed_operation: TensorParallelLayer | None = None
quantization_operation: ConversionOps | None = None

collected_tensors: dict[str, list[Future]] = field(default_factory=lambda: defaultdict(list), init=False)
layer_targets: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set), init=False)
# Restrict the attributes that can be attached
__slots__ = (
"source_patterns",
"target_patterns",
"compiled_sources",
"distributed_operation",
"quantization_operation",
"collected_tensors",
"layer_targets",
"_original_source_patterns",
"_original_target_patterns",
"_was_used",
)

# Those are needed to be able to reverse correctly the transform, as the patterns may be processed
_original_source_patterns: list[str] = field(init=False)
_original_target_patterns: list[str] = field(init=False)
def __init__(self, source_patterns: str | list[str], target_patterns: str | list[str]):
self.source_patterns: list[str] = source_patterns
self.target_patterns: list[str] = target_patterns
# Those are needed to be able to reverse correctly the transform, as the patterns may be processed
self._original_source_patterns = self.source_patterns.copy()
self._original_target_patterns = self.target_patterns.copy()

def __setattr__(self, name, value):
if name in ("source_patterns", "target_patterns"):
# We do not allow to re-set the patterns, as they are linked between each other and changing one
# without the other can mess-up with the capturing groups/compiled sources
if hasattr(self, name):
raise ValueError(f"Cannot assign to field {name}, you should create a new instance")
# Switch str to list
elif isinstance(value, str):
value = [value]
object.__setattr__(self, name, value)
# Init fields that will be used during conversion
self.distributed_operation: TensorParallelLayer | None = None
self.quantization_operation: ConversionOps | None = None
self.collected_tensors: dict[str, list[Future]] = defaultdict(list)
self.layer_targets: dict[str, set[str]] = defaultdict(set)

def __post_init__(self):
# Due to how our `_checkpoint_conversion_mapping` mappings are written, we need a few exceptions here
# when instantiating the reverse mapping (i.e. the targets become sources, and sources become targets)
# The issues lie in the sources usually, so here we need to check the targets for the reversed mapping
# Flag to notice if the Transform was used
self._was_used = False

# We need to copy the exact original patterns to later reverse (before processing may change them)
self._original_source_patterns = self.source_patterns.copy()
self._original_target_patterns = self.target_patterns.copy()
# We need to process a few exceptions here when instantiating the reverse mapping (i.e. the targets become
# sources, and sources become targets). The issues lie in the sources usually, so here we need to check the
# targets for the reversed mapping

# Process target_patterns: detect capturing groups and replace with \1
# Store the original capturing group patterns for reverse mapping
@@ -657,6 +655,20 @@ def __post_init__(self):
branches.append(f"(?P<{group_name}>{pattern})")
self.compiled_sources = re.compile("|".join(branches))
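For reference, the single compiled regex built here uses one named branch per source pattern, so a single `search` both matches a key and reveals which source pattern matched. A standalone sketch of that technique with made-up patterns:

```python
import re

sources = [r"encoder\.(.+)\.linear_q", r"decoder\.(.+)\.query_net"]
# One named branch per source pattern: (?P<_0>...)|(?P<_1>...)
compiled = re.compile("|".join(f"(?P<_{i}>{p})" for i, p in enumerate(sources)))

match = compiled.search("decoder.layers.2.query_net.weight")
# The first (and only) non-None named group tells us which branch matched.
branch = next(name for name, value in match.groupdict().items() if value is not None)
print(branch, "->", sources[int(branch[1:])])
# _1 -> decoder\.(.+)\.query_net
```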

def __repr__(self):
return f"{self.__class__.__name__}(source_patterns={self.source_patterns}, target_patterns={self.target_patterns})"

def __setattr__(self, name, value):
if name in ("source_patterns", "target_patterns"):
# We do not allow to re-set the patterns, as they are linked between each other and changing one
# without the other can mess-up with the capturing groups/compiled sources
if hasattr(self, name):
raise ValueError(f"Cannot assign to field {name}, you should create a new instance")
# Switch str to list
elif isinstance(value, str):
value = [value]
object.__setattr__(self, name, value)

def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future):
self.collected_tensors[source_pattern].append(future)
self.layer_targets[target_key].add(source_key)
@@ -673,6 +685,9 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
if match_object is None:
return source_key, None

# We have a match, so the Transform was used
self._was_used = True

# Find the source that produced the match (it's the first group that matched, as the search stops after first branch match)
matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None)
source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
@@ -731,11 +746,23 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:

return collected_tensors

def was_used(self) -> bool:
"""
Return whether the current Transform matched any weights during loading/saving. This is needed because some
weight renaming transforms are not bijective: if we drop/add full parts of a name with PrefixChange, we
lose information that we cannot recover unless we know whether the Transform was already used (say we
have a prefix to drop; during saving, we need to know whether the checkpoints we loaded actually
contained that prefix before deciding whether to add it back).
"""
return self._was_used
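A tiny illustration of the non-bijectivity problem the docstring describes, with a made-up key and plain `re` (only the idea, not the library's actual bookkeeping):

```python
import re

checkpoint_key = "vision_model.encoder.layers.0.mlp.fc1.weight"
loaded_key = re.sub(r"^vision_model\.", "", checkpoint_key)

# When saving, blindly re-adding "vision_model." would be wrong for checkpoints
# that never had the prefix; remembering whether the transform actually matched
# something during loading (analogous to was_used()) disambiguates the two cases.
prefix_was_stripped = loaded_key != checkpoint_key
saved_key = f"vision_model.{loaded_key}" if prefix_was_stripped else loaded_key
assert saved_key == checkpoint_key
```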


@dataclass(slots=True)
class WeightRenaming(WeightTransform):
# Special case of WeightTransform that only renames keys without any conversion.

# Needs to be empty, otherwise the class will not be slotted
__slots__ = ()

def convert(
self,
layer_name: str,
@@ -770,19 +797,72 @@ def convert(
return collected_tensors


class PrefixChange(WeightRenaming):
"""
Special case of WeightRenaming, used to simplify adding/removing full parts of a weight name. The regexes
that are needed for such operations are complex, so this is a much easier API for such cases.
"""

__slots__ = (
"prefix_to_add",
"prefix_to_remove",
"model_prefix",
)

def __init__(
self, prefix_to_add: str | None = None, prefix_to_remove: str | None = None, model_prefix: str | None = None
):
if (prefix_to_add is None) ^ (prefix_to_remove is not None):
raise ValueError("You must provide only one of `prefix_to_add` and `prefix_to_remove`")

self.prefix_to_add = prefix_to_add
self.prefix_to_remove = prefix_to_remove
self.model_prefix = "" if model_prefix is None else model_prefix
prefix = rf"{self.model_prefix}\." if self.model_prefix != "" else ""

if prefix_to_add is not None:
super().__init__(
# We use a negative lookahead to avoid adding the prefix if we detect that it's already present
source_patterns=rf"^{prefix}(?:(?!{prefix_to_add}\.))(.+)$",
target_patterns=rf"{prefix}{prefix_to_add}\.\1",
)
else:
super().__init__(source_patterns=rf"^{prefix}{prefix_to_remove}\.(.+)$", target_patterns=rf"{prefix}\1")

def reverse_transform(self) -> WeightTransform:
"""Reverse the current `WeightTransform` instance, to be able to save with the opposite weight transformations."""
# TODO: check this and relax when quantizer have `reverse_op`
if self.quantization_operation is not None:
raise ValueError("Cannot reverse the transform with TP or quantization")

# Only one of the 2 can ever be used, so 1 is always None
return PrefixChange(
prefix_to_add=self.prefix_to_remove, prefix_to_remove=self.prefix_to_add, model_prefix=self.model_prefix
)

def with_submodel_prefix(self, prefix: str) -> PrefixChange:
new_prefix = f"{prefix}.{self.model_prefix}" if self.model_prefix != "" else prefix
return PrefixChange(
prefix_to_add=self.prefix_to_add, prefix_to_remove=self.prefix_to_remove, model_prefix=new_prefix
)
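For reference, the two regex shapes built above behave roughly like this when applied with plain `re.sub` (the keys are invented and the escaping of the targets is simplified for standalone use):

```python
import re

# prefix_to_remove="vision_model", no model prefix:
remove_src, remove_tgt = r"^vision_model\.(.+)$", r"\1"
print(re.sub(remove_src, remove_tgt, "vision_model.encoder.layers.0.layer_norm1.weight"))
# -> "encoder.layers.0.layer_norm1.weight"

# prefix_to_add="vision_model", no model prefix. The negative lookahead leaves keys
# that already carry the prefix untouched:
add_src, add_tgt = r"^(?:(?!vision_model\.))(.+)$", r"vision_model.\1"
print(re.sub(add_src, add_tgt, "encoder.layers.0.layer_norm1.weight"))
# -> "vision_model.encoder.layers.0.layer_norm1.weight"
print(re.sub(add_src, add_tgt, "vision_model.encoder.layers.0.layer_norm1.weight"))
# -> unchanged, the prefix is already present
```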


# List of classes that are known to be able to use m:n
_INTERNAL_MANY_TO_MANY_CONVERSIONS = (
ErnieFuseAndSplitTextVisionExperts,
ErnieSplitAndDecoupleTextVisionExperts,
)


@dataclass(slots=True)
class WeightConverter(WeightTransform):
operations: list[ConversionOps] = field(default_factory=list, repr=False)
__slots__ = ("operations",)

def __init__(
self, source_patterns: str | list[str], target_patterns: str | list[str], operations: list[ConversionOps]
):
super().__init__(source_patterns, target_patterns)
self.operations: list[ConversionOps] = operations

def __post_init__(self):
WeightTransform.__post_init__(self)
if bool(len(self.source_patterns) - 1) + bool(len(self.target_patterns) - 1) >= 2:
# We allow many-to-many only if we use an internal operation that can handle it
if not any(isinstance(op, _INTERNAL_MANY_TO_MANY_CONVERSIONS) for op in self.operations):
@@ -1342,8 +1422,11 @@ def convert_and_load_state_dict_in_model(
# `cancel_futures=True` in case the program was interrupted, to avoid wasting time on exit
thread_pool.shutdown(wait=False, cancel_futures=True)

# Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
model._weight_conversions = weight_mapping
# Keep the current weight conversion mapping for later saving (in case it was coming directly from the user), but
# only if it was used, i.e. it matched any weight from the checkpoints
model_specific_conversions = [conversion for conversion in weight_mapping if conversion.was_used()]
Comment on lines +1425 to +1427

Member: Does it not create another loophole if some conversions aren't used by the model? But probably, if the test checks the full list of conversions, we can consciously decide to add task-specific patterns like lm_head.

Contributor: Not sure I can follow, do you have a small example where you think we would fall through?

Member (Author): The test that checks that all conversions are used (to make sure we don't add useless conversions) does it by checking the keys between the original model and the saved weights (because from our tiny models, we only know the "correct" keys, and the saved keys mimic the "wrong" keys), so it will still check it correctly!

model._weight_conversions = model_specific_conversions

return loading_info, disk_offload_index


@@ -1360,12 +1443,20 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch

# Do not resave with the legacy renaming, if present
weight_conversions = get_model_conversion_mapping(model, add_legacy=False)
# If the model had no `_weight_conversions` attached, drop any PrefixChange transform - this is because the
# model was almost surely instantiated from scratch (at least not from `from_pretrained`), and PrefixChange with
# `prefix_to_remove` would otherwise add an unwanted prefix (as we don't have any information about whether the prefix
# was there or not during load)
weight_conversions = [x for x in weight_conversions if not isinstance(x, PrefixChange)]
weight_conversions = weight_conversions if len(weight_conversions) > 0 else None

# We did not find any operations to perform -> quick escape
if weight_conversions is None:
return state_dict

# Important: we need to reverse the order here, so that potential conversions from submodels are performed first
weight_conversions = weight_conversions[::-1]

# Reverse all Transform to correctly match keys
reverse_weight_conversion = [conversion.reverse_transform() for conversion in weight_conversions]
# If we are still here, we need to create the (reverse) conversion mapping from scratch
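To see why the list order is reversed before reverting, here is a schematic example with two made-up renames, assuming they compose sequentially per key (an illustration of the ordering argument only, not the library's exact application semantics): the steps must be undone in the opposite order to the one used while loading.

```python
import re

# Loading: a top-level rename followed by a rename inside the submodel (both invented).
load_steps = [
    (r"^text_model\.(.+)$", r"\1"),        # drop a submodel prefix
    (r"^embeddings\.(.+)$", r"embed.\1"),  # rename inside the submodel
]
key = "text_model.embeddings.weight"
for pattern, repl in load_steps:
    key = re.sub(pattern, repl, key)
assert key == "embed.weight"

# Saving: each step is reversed AND the list is walked in reverse order; otherwise
# the inner rename would never match the already re-prefixed key.
save_steps = [
    (r"^embed\.(.+)$", r"embeddings.\1"),
    (r"^(?!text_model\.)(.+)$", r"text_model.\1"),
]
for pattern, repl in save_steps:
    key = re.sub(pattern, repl, key)
assert key == "text_model.embeddings.weight"
```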