diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 2a6dc23ba9d0..88155e4b4daa 100755
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -22,9 +22,11 @@
     Concatenate,
     ErnieFuseAndSplitTextVisionExperts,
     MergeModulelist,
+    PrefixChange,
     Transpose,
     WeightConverter,
     WeightRenaming,
+    WeightTransform,
 )
@@ -75,6 +77,20 @@
     "pp_chart2table": "llava",
     "gemma3n_text": "qwen3_5_text",
     "qwen3_5_moe_text": "qwen3_5_text",
+    "altclip_vision_model": "clip_vision_model",
+    "chinese_clip_vision_model": "clip_vision_model",
+    "clipseg_vision_model": "clip_vision_model",
+    "metaclip_2_vision_model": "clip_vision_model",
+    "mlcd_vision": "clip_vision_model",
+    "mlcd": "clip_vision_model",
+    "siglip_vision_model": "clip_vision_model",
+    "siglip2_vision_model": "clip_vision_model",
+    "xclip_vision_model": "clip_vision_model",
+    "clipseg_text_model": "clip_text_model",
+    "metaclip_2_text_model": "clip_text_model",
+    "siglip_text_model": "clip_text_model",
+    "siglip2_text_model": "clip_text_model",
+    "xclip_text_model": "clip_text_model",
 }
@@ -96,6 +112,8 @@ def _build_checkpoint_conversion_mapping():
             WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"),
             WeightRenaming(source_patterns=r"^image_newline", target_patterns="model.image_newline"),
         ],
+        "clip_vision_model": [PrefixChange(prefix_to_remove="vision_model")],
+        "clip_text_model": [PrefixChange(prefix_to_remove="text_model")],
         "video_llava": [
             WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"),
             WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
@@ -568,17 +586,17 @@ def _build_checkpoint_conversion_mapping():
             WeightRenaming(r"transf_decoder\._decoder\.final_layer_norm", r"decoder.norm"),
             WeightRenaming(r"transf_decoder\._decoder\.layers", r"decoder.layers"),
             WeightRenaming(r"encoder_decoder_proj\.", r"decoder.proj."),
-            WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.(.+).self_attn.q_proj"),
-            WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_k", r"encoder.(.+).self_attn.k_proj"),
-            WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_v", r"encoder.(.+).self_attn.v_proj"),
-            WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_out", r"encoder.(.+).self_attn.o_proj"),
-            WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_pos", r"encoder.(.+).self_attn.relative_k_proj"),
-            WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_u", r"encoder.(.+).self_attn.bias_u"),
-            WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_v", r"encoder.(.+).self_attn.bias_v"),
-            WeightRenaming(r"\.first_sub_layer\.query_net", r".self_attn.q_proj"),
-            WeightRenaming(r"\.first_sub_layer\.key_net", r".self_attn.k_proj"),
-            WeightRenaming(r"\.first_sub_layer\.value_net", r".self_attn.v_proj"),
-            WeightRenaming(r"\.first_sub_layer\.out_projection", r".self_attn.o_proj"),
+            WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.\1.self_attn.q_proj"),
+            WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_k", r"encoder.\1.self_attn.k_proj"),
+            WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_v", r"encoder.\1.self_attn.v_proj"),
+            WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_out", r"encoder.\1.self_attn.o_proj"),
+            WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_pos", r"encoder.\1.self_attn.relative_k_proj"),
+            WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_u", r"encoder.\1.self_attn.bias_u"),
+            WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_v", r"encoder.\1.self_attn.bias_v"),
+            WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.query_net", r"decoder.\1.self_attn.q_proj"),
+            WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.key_net", r"decoder.\1.self_attn.k_proj"),
+            WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.value_net", r"decoder.\1.self_attn.v_proj"),
+            WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.out_projection", r"decoder.\1.self_attn.o_proj"),
             WeightRenaming(r"\.second_sub_layer\.query_net", r".encoder_attn.q_proj"),
             WeightRenaming(r"\.second_sub_layer\.key_net", r".encoder_attn.k_proj"),
             WeightRenaming(r"\.second_sub_layer\.value_net", r".encoder_attn.v_proj"),
@@ -623,10 +641,16 @@ def register_checkpoint_conversion_mapping(
     _checkpoint_conversion_mapping_cache[model_type] = mapping


-def extract_weight_conversions_for_model(model: PreTrainedModel) -> list[WeightConverter | WeightRenaming] | None:
+def extract_weight_conversions_for_model(model: PreTrainedModel, model_prefix: str) -> list[WeightTransform] | None:
    model_type = getattr(model.config, "model_type", None)
    if model_type is not None:
        model_specific_conversions = get_checkpoint_conversion_mapping(model_type)
+        # If this model is nested as a submodule, attach its prefix to any `PrefixChange` instance, as otherwise
+        # we don't know where exactly in the module name chain the prefix needs to be added/removed
+        if model_specific_conversions is not None and model_prefix != "":
+            for i, conversion in enumerate(model_specific_conversions):
+                if isinstance(conversion, PrefixChange):
+                    model_specific_conversions[i] = conversion.with_submodel_prefix(model_prefix)
        return model_specific_conversions
    return None
@@ -636,7 +660,7 @@ def get_model_conversion_mapping(
     key_mapping: dict[str, str] | None = None,
     hf_quantizer: HfQuantizer | None = None,
     add_legacy: bool = True,
-) -> list[WeightConverter | WeightRenaming]:
+) -> list[WeightTransform]:
     """
     For a given `model`, obtain the weight conversion mapping if any are registered either as a simple renaming
     `_checkpoint_conversion_mapping` class argument, or in the general WeightConverter mapping.
@@ -651,22 +675,13 @@
     if key_mapping is not None:
         weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()]

-    # Model have several `PreTrainedModel` within with the same model type
-    # For ex: XForConditionalGeneration -> XModel. We don't want to apply the same
-    # conversion pattern twice because of that
+    # A model may contain several `PreTrainedModel` instances with the same model type, e.g. XForConditionalGeneration -> XModel.
+    # We don't want to apply the same conversion pattern twice because of that.
     seen_model_types = set()
-    if (conversions := extract_weight_conversions_for_model(model)) is not None:
-        weight_conversions.extend(conversions)
-        seen_model_types.add(model.config.model_type)
-    # Recurse over submodules and collect all conversions
-    for submodule in model.modules():
-        if (
-            submodule is not model
-            and isinstance(submodule, PreTrainedModel)
-            and submodule.config.model_type not in seen_model_types
-        ):
-            conversions = extract_weight_conversions_for_model(submodule)
+    for name, submodule in model.named_modules():
+        if isinstance(submodule, PreTrainedModel) and submodule.config.model_type not in seen_model_types:
+            conversions = extract_weight_conversions_for_model(submodule, name)
             if conversions is not None:
                 weight_conversions.extend(conversions)
                 seen_model_types.add(submodule.config.model_type)
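For orientation: the collection above now walks `named_modules()` once (which yields the root model under the empty name), dedupes by `model_type`, and forwards each submodule's path so that `PrefixChange.with_submodel_prefix` knows where in the name chain to operate. Below is a minimal, self-contained sketch of that traversal; the toy classes are hypothetical stand-ins for nested `PreTrainedModel` instances and are not part of this PR:

    # Sketch of the dedup-by-model_type traversal; ToyText/ToyWrapper are hypothetical.
    import torch.nn as nn

    class ToyText(nn.Module):
        model_type = "toy_text"

    class ToyWrapper(nn.Module):
        model_type = "toy_wrapper"

        def __init__(self):
            super().__init__()
            self.text_model = ToyText()
            self.text_model_2 = ToyText()  # same model_type: must not be collected twice

    root = ToyWrapper()
    seen, collected = set(), []
    for name, submodule in root.named_modules():  # includes the root itself under name ""
        model_type = getattr(submodule, "model_type", None)
        if model_type is not None and model_type not in seen:
            # `name` is the submodule prefix ("" for the root, "text_model" for the first child);
            # in the real code it is forwarded to PrefixChange.with_submodel_prefix
            collected.append((model_type, name))
            seen.add(model_type)
    print(collected)  # [('toy_wrapper', ''), ('toy_text', 'text_model')]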
diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index b43d5354e8ac..cd0710649c91 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -25,7 +25,6 @@
 from concurrent.futures import Future, ThreadPoolExecutor
 from contextlib import contextmanager
 from copy import deepcopy
-from dataclasses import dataclass, field
 from itertools import chain
 from typing import TYPE_CHECKING, Any
@@ -579,41 +578,40 @@ def process_source_pattern(source_pattern: str, target_pattern: str) -> str:
     return source_pattern


-@dataclass(slots=True)
 class WeightTransform:
-    source_patterns: str | list[str] = field(init=True)
-    target_patterns: str | list[str] = field(init=True)
-    compiled_sources: re.Pattern = field(init=False)
-
-    distributed_operation: TensorParallelLayer | None = None
-    quantization_operation: ConversionOps | None = None
-
-    collected_tensors: dict[str, list[Future]] = field(default_factory=lambda: defaultdict(list), init=False)
-    layer_targets: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set), init=False)
+    # Restrict the attributes that can be attached
+    __slots__ = (
+        "source_patterns",
+        "target_patterns",
+        "compiled_sources",
+        "distributed_operation",
+        "quantization_operation",
+        "collected_tensors",
+        "layer_targets",
+        "_original_source_patterns",
+        "_original_target_patterns",
+        "_was_used",
+    )

-    # Those are needed to be able to reverse correctly the transform, as the patterns may be processed
-    _original_source_patterns: list[str] = field(init=False)
-    _original_target_patterns: list[str] = field(init=False)
+    def __init__(self, source_patterns: str | list[str], target_patterns: str | list[str]):
+        self.source_patterns: list[str] = source_patterns
+        self.target_patterns: list[str] = target_patterns
+        # Those are needed to be able to reverse correctly the transform, as the patterns may be processed
+        self._original_source_patterns = self.source_patterns.copy()
+        self._original_target_patterns = self.target_patterns.copy()

-    def __setattr__(self, name, value):
-        if name in ("source_patterns", "target_patterns"):
-            # We do not allow to re-set the patterns, as they are linked between each other and changing one
-            # without the other can mess-up with the capturing groups/compiled sources
-            if hasattr(self, name):
-                raise ValueError(f"Cannot assign to field {name}, you should create a new instance")
-            # Switch str to list
-            elif isinstance(value, str):
-                value = [value]
-        object.__setattr__(self, name, value)
+        # Init fields that will be used during conversion
+        self.distributed_operation: TensorParallelLayer | None = None
+        self.quantization_operation: ConversionOps | None = None
+        self.collected_tensors: dict[str, list[Future]] = defaultdict(list)
+        self.layer_targets: dict[str, set[str]] = defaultdict(set)

-    def __post_init__(self):
-        # Due to how our `_checkpoint_conversion_mapping` mappings are written, we need a few exceptions here
-        # when instantiating the reverse mapping (i.e. the targets become sources, and sources become targets)
-        # The issues lie in the sources usually, so here we need to check the targets for the reversed mapping
+        # Flag recording whether the Transform was used
+        self._was_used = False

-        # We need to copy the exact original patterns to later reverse (before processing may change them)
-        self._original_source_patterns = self.source_patterns.copy()
-        self._original_target_patterns = self.target_patterns.copy()
+        # We need to process a few exceptions here when instantiating the reverse mapping (i.e. the targets become
+        # sources, and sources become targets). The issues lie in the sources usually, so here we need to check the
+        # targets for the reversed mapping

         # Process target_patterns: detect capturing groups and replace with \1
         # Store the original capturing group patterns for reverse mapping
@@ -657,6 +655,20 @@ def __post_init__(self):
                 branches.append(f"(?P<{group_name}>{pattern})")
         self.compiled_sources = re.compile("|".join(branches))

+    def __repr__(self):
+        return f"{self.__class__.__name__}(source_patterns={self.source_patterns}, target_patterns={self.target_patterns})"
+
+    def __setattr__(self, name, value):
+        if name in ("source_patterns", "target_patterns"):
+            # We do not allow to re-set the patterns, as they are linked between each other and changing one
+            # without the other can mess-up with the capturing groups/compiled sources
+            if hasattr(self, name):
+                raise ValueError(f"Cannot assign to field {name}, you should create a new instance")
+            # Switch str to list
+            elif isinstance(value, str):
+                value = [value]
+        object.__setattr__(self, name, value)
+
     def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future):
         self.collected_tensors[source_pattern].append(future)
         self.layer_targets[target_key].add(source_key)
@@ -673,6 +685,9 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
         if match_object is None:
             return source_key, None

+        # We have a match, so the Transform was used
+        self._was_used = True
+
         # Find the source that produced the match (it's the first group that matched, as the search stops after first branch match)
         matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None)
         source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
@@ -731,11 +746,23 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:

         return collected_tensors

+    def was_used(self) -> bool:
+        """
+        Return whether the current Transform matched any weights during loading/saving. This is needed as some
+        weight renaming transforms are not bijective: if we drop/add full parts of a name with PrefixChange, we
+        lose information that we cannot recover later unless we know whether the Transform was already used (say
+        we have a prefix to drop: when saving, we need to know whether the checkpoints we loaded contained that
+        prefix, in order to decide whether to add it back or not).
+        """
+        return self._was_used
+

-@dataclass(slots=True)
 class WeightRenaming(WeightTransform):
     # Special case of WeightTransform that only renames keys without any conversion.

+    # Needs to be empty, otherwise the class will not be slotted
+    __slots__ = ()
+
     def convert(
         self,
         layer_name: str,
@@ -770,6 +797,56 @@ def convert(

         return collected_tensors


+class PrefixChange(WeightRenaming):
+    """
+    Special case of WeightRenaming, used to simplify adding/removing full parts of a weight name. The regexes
+    that are needed for such operations are complex, so this is a much easier API for such cases.
+    """
+
+    __slots__ = (
+        "prefix_to_add",
+        "prefix_to_remove",
+        "model_prefix",
+    )
+
+    def __init__(
+        self, prefix_to_add: str | None = None, prefix_to_remove: str | None = None, model_prefix: str | None = None
+    ):
+        if (prefix_to_add is None) ^ (prefix_to_remove is not None):
+            raise ValueError("You must provide exactly one of `prefix_to_add` and `prefix_to_remove`")
+
+        self.prefix_to_add = prefix_to_add
+        self.prefix_to_remove = prefix_to_remove
+        self.model_prefix = "" if model_prefix is None else model_prefix
+        prefix = rf"{self.model_prefix}\." if self.model_prefix != "" else ""
+
+        if prefix_to_add is not None:
+            super().__init__(
+                # We use a negative lookahead to avoid adding the prefix if we detect that it's already present
+                source_patterns=rf"^{prefix}(?:(?!{prefix_to_add}\.))(.+)$",
+                target_patterns=rf"{prefix}{prefix_to_add}\.\1",
+            )
+        else:
+            super().__init__(source_patterns=rf"^{prefix}{prefix_to_remove}\.(.+)$", target_patterns=rf"{prefix}\1")
+
+    def reverse_transform(self) -> WeightTransform:
+        """Reverse the current `WeightTransform` instance, to be able to save with the opposite weight transformations."""
+        # TODO: check this and relax when quantizers have `reverse_op`
+        if self.quantization_operation is not None:
+            raise ValueError("Cannot reverse the transform with TP or quantization")
+
+        # Only one of the two can ever be set, so the other is always None
+        return PrefixChange(
+            prefix_to_add=self.prefix_to_remove, prefix_to_remove=self.prefix_to_add, model_prefix=self.model_prefix
+        )
+
+    def with_submodel_prefix(self, prefix: str) -> PrefixChange:
+        new_prefix = f"{prefix}.{self.model_prefix}" if self.model_prefix != "" else prefix
+        return PrefixChange(
+            prefix_to_add=self.prefix_to_add, prefix_to_remove=self.prefix_to_remove, model_prefix=new_prefix
+        )
+
+
 # List of classes that are known to be able to use m:n
 _INTERNAL_MANY_TO_MANY_CONVERSIONS = (
     ErnieFuseAndSplitTextVisionExperts,
 )


-@dataclass(slots=True)
 class WeightConverter(WeightTransform):
-    operations: list[ConversionOps] = field(default_factory=list, repr=False)
+    __slots__ = ("operations",)
+
+    def __init__(
+        self, source_patterns: str | list[str], target_patterns: str | list[str], operations: list[ConversionOps]
+    ):
+        super().__init__(source_patterns, target_patterns)
+        self.operations: list[ConversionOps] = operations

-    def __post_init__(self):
-        WeightTransform.__post_init__(self)
         if bool(len(self.source_patterns) - 1) + bool(len(self.target_patterns) - 1) >= 2:
             # We allow many-to-many only if we use an internal operation that can handle it
             if not any(isinstance(op, _INTERNAL_MANY_TO_MANY_CONVERSIONS) for op in self.operations):
@@ -1342,8 +1422,11 @@ def convert_and_load_state_dict_in_model(
     # `cancel_futures=True` in case the program was interrupted, to avoid wasting time on exit
     thread_pool.shutdown(wait=False, cancel_futures=True)

-    # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
-    model._weight_conversions = weight_mapping
+    # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user), but
+    # only if it was used, i.e. it matched at least one weight from the checkpoints
+    model_specific_conversions = [conversion for conversion in weight_mapping if conversion.was_used()]
+    model._weight_conversions = model_specific_conversions
+
     return loading_info, disk_offload_index
@@ -1360,12 +1443,20 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
         # Do not resave with the legacy renaming, if present
         weight_conversions = get_model_conversion_mapping(model, add_legacy=False)

+        # If the model had no `_weight_conversions` attached, drop any PrefixChange transform - this is because the
+        # model was almost surely instantiated from scratch (at least not from `from_pretrained`), and a PrefixChange
+        # with `prefix_to_remove` would otherwise add an unwanted prefix (as we don't have any information about
+        # whether the prefix was there or not during load)
+        weight_conversions = [x for x in weight_conversions if not isinstance(x, PrefixChange)]
     weight_conversions = weight_conversions if len(weight_conversions) > 0 else None

     # We did not find any operations to perform -> quick escape
     if weight_conversions is None:
         return state_dict

+    # Important: we need to reverse the order here, so that potential conversions from submodels are performed first
+    weight_conversions = weight_conversions[::-1]
+
     # Reverse all Transform to correctly match keys
     reverse_weight_conversion = [conversion.reverse_transform() for conversion in weight_conversions]

     # If we are still here, we need to create the (reverse) conversion mapping from scratch
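To make the new `PrefixChange` behavior concrete, here is a small sketch that exercises the exact regexes its `__init__` builds (with an empty `model_prefix`, so the `prefix` fragment is empty), using nothing but `re`; the sample keys are illustrative only:

    # Sketch of the two patterns PrefixChange builds, copied from the diff above.
    import re

    # prefix_to_remove="vision_model": strip the prefix when it is present
    remove_src = r"^vision_model\.(.+)$"
    print(re.sub(remove_src, r"\1", "vision_model.encoder.layers.0.mlp.fc1.weight"))
    # -> encoder.layers.0.mlp.fc1.weight
    print(re.search(remove_src, "encoder.layers.0.mlp.fc1.weight"))  # None: transform not "used"

    # prefix_to_add="model": the negative lookahead keeps already-prefixed keys untouched
    add_src = r"^(?:(?!model\.))(.+)$"
    print(re.sub(add_src, r"model.\1", "embed_tokens.weight"))  # -> model.embed_tokens.weight
    print(re.search(add_src, "model.embed_tokens.weight"))      # None: no double prefix

The second `None` is exactly why `was_used()` matters: when a checkpoint already carries the prefix, the transform never matches, and re-saving must not apply the reverse transform.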
"""Testing suite for the PyTorch AltCLIP model.""" -import copy import inspect -import os -import re -import tempfile import unittest import numpy as np @@ -25,8 +21,6 @@ from parameterized import parameterized from transformers import AltCLIPConfig, AltCLIPProcessor, AltCLIPTextConfig, AltCLIPVisionConfig -from transformers.conversion_mapping import get_model_conversion_mapping -from transformers.core_model_loading import WeightRenaming, process_target_pattern from transformers.testing_utils import is_flaky, require_torch, require_vision, slow, torch_device from transformers.utils import is_torch_available, is_vision_available @@ -44,7 +38,6 @@ if is_torch_available(): import torch import torch.nn as nn - from safetensors.torch import load_file from transformers import AltCLIPModel, AltCLIPTextModel, AltCLIPVisionModel @@ -472,54 +465,6 @@ def test_model_from_pretrained(self): model = AltCLIPModel.from_pretrained(model_name) self.assertIsNotNone(model) - def test_reverse_loading_mapping(self, check_keys_were_modified=True): - # AltCLIP applies legacy conversion which is never reversed, so we won't get - # matching state dict keys after re-saving it back - - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - # Each individual model is a subtest - with self.subTest(model_class.__name__): - model = model_class(copy.deepcopy(config)) - # Skip if no conversions - conversions = get_model_conversion_mapping(model, add_legacy=False) - if len(conversions) == 0: - self.skipTest(f"No conversion found for {model_class}") - - # Find the model keys, so the targets according to the conversions - model_keys = list(model.state_dict().keys()) - - with tempfile.TemporaryDirectory() as tmpdirname: - # Serialize with reverse mapping - model.save_pretrained(tmpdirname) - state_dict = load_file(os.path.join(tmpdirname, "model.safetensors")) - # Get all the serialized keys that we just saved according to the reverse mapping - serialized_keys = list(state_dict.keys()) - - if check_keys_were_modified: - # They should be different, otherwise we did not perform any mapping - self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!") - - # Check that for each conversion entry, we at least map to one key - for conversion in conversions: - for source_pattern in conversion.source_patterns: - # Sometimes the mappings specify keys that are tied, so absent from the saved state dict - if isinstance(conversion, WeightRenaming): - # We need to revert the target pattern to make it compatible with regex search - target_pattern_reversed = conversion.target_patterns[0] - captured_group = process_target_pattern(source_pattern)[1] - if captured_group: - target_pattern_reversed = target_pattern_reversed.replace(r"\1", captured_group) - if any(re.search(target_pattern_reversed, k) for k in model.all_tied_weights_keys.keys()): - continue - num_matches = sum(re.search(source_pattern, key) is not None for key in serialized_keys) - self.assertTrue( - num_matches > 0, - f"`{source_pattern}` in `{conversion}` did not match any of the source keys. 
" - "This indicates whether that the pattern is not properly written, or that it could not be reversed correctly", - ) - @require_vision @require_torch diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 24f278c24704..ae502023e7d9 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -44,7 +44,7 @@ set_seed, ) from transformers.conversion_mapping import get_model_conversion_mapping -from transformers.core_model_loading import WeightRenaming, process_target_pattern +from transformers.core_model_loading import PrefixChange, WeightRenaming, process_target_pattern from transformers.integrations import HfDeepSpeedConfig from transformers.integrations.deepspeed import ( is_deepspeed_available, @@ -4772,6 +4772,11 @@ def test_reverse_loading_mapping(self, check_keys_were_modified=True, skip_base_ conversions = get_model_conversion_mapping(model, add_legacy=False) if len(conversions) == 0: self.skipTest(f"No conversion found for {model_class}") + # The PrefixChange conersions are only there for BC with hub checkpoints, but cannot be tested + # for as we skip them automatically if they are not present in loaded checkpoints (we want to + # mess up the prefixes only if the loaded checkpoints were doing so as well) + if all(isinstance(conversion, PrefixChange) for conversion in conversions): + self.skipTest(f"Only PrefixChange conversions found for {model_class}") # Find the model keys, so the targets according to the conversions model_keys = list(model.state_dict().keys()) @@ -4789,6 +4794,11 @@ def test_reverse_loading_mapping(self, check_keys_were_modified=True, skip_base_ # Check that for each conversion entry, we at least map to one key for conversion in conversions: + # The PrefixChange conersions are only there for BC with hub checkpoints, but cannot be tested + # for as we skip them automatically if they are not present in loaded checkpoints (we want to + # mess up the prefixes only if the loaded checkpoints were doing so as well) + if isinstance(conversion, PrefixChange): + continue for source_pattern in conversion.source_patterns: # Some patterns are written for gen-model only and won't be applied on base model if "lm_head" in source_pattern and model_class not in [ @@ -5661,7 +5671,12 @@ def compare_state_dicts(state_dict1, state_dict2) -> bool: """Make sure 2 state dicts are the exact same""" # Make sure the keys are the exact same if sorted(state_dict1.keys()) != sorted(state_dict2.keys()): - raise ValueError("The keys of both state dict are not the same") + in1_not2 = sorted(set(state_dict1.keys()) - set(state_dict2.keys())) + in2_not1 = sorted(set(state_dict2.keys()) - set(state_dict1.keys())) + raise ValueError( + f"The keys of both state dict are not the same.\nKeys found in the first item but not second: {in1_not2}" + f"\nKeys found in the second item but not first: {in2_not1}" + ) for k, v1 in state_dict1.items(): v2 = state_dict2[k] diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index 2875f44088a7..a358822d19f8 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py
index 2875f44088a7..a358822d19f8 100644
--- a/tests/utils/test_core_model_loading.py
+++ b/tests/utils/test_core_model_loading.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import unittest
 from types import SimpleNamespace
@@ -27,6 +28,7 @@
     LinearToConv3d,
     MergeModulelist,
     PermuteForRope,
+    PrefixChange,
     WeightConverter,
     WeightRenaming,
     build_glob_alternation,
@@ -210,10 +212,11 @@ class DummyRoot(nn.Module):
     base_model_prefix = "model"
     config: PretrainedConfig

-    def __init__(self, add_extra_moe=False):
+    def __init__(self, add_extra_moe=False, with_mlp=True):
         super().__init__()
         self.model = DummyTopModel(add_extra_moe)
-        self.mlp = DummyMLP()
+        if with_mlp:
+            self.mlp = DummyMLP()


 class TestConvertAndLoadStateDict(unittest.TestCase):
@@ -778,6 +781,245 @@ def test_register_checkpoint_conversion_mapping_overwrites(self):

         self.assertEqual(len(get_checkpoint_conversion_mapping("foobarbaz")), 2)

+    def test_can_remove_prefix(self):
+        model = DummyRoot()
+        model.config = PretrainedConfig()
+
+        bad_serialized_checkpoints = {f"bad_name.{k}": v.clone() for k, v in model.state_dict().items()}
+        weight_mapping = [PrefixChange(prefix_to_remove="bad_name")]
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            bad_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, re-adding the bad prefix
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(bad_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == bad_serialized_checkpoints[k]).all())
+
+        # Now, check that using the same conversion with already good keys works when loading and resaving
+        good_serialized_checkpoints = {k: v.clone() for k, v in model.state_dict().items()}
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            good_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, i.e. it will not re-add the bad prefix since it was
+        # not present at loading time
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(good_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == good_serialized_checkpoints[k]).all())
+
+        # Now, use a fresh model, without going through loading first, so the model won't have `_weight_conversions` attached
+        # and the prefix should not be added when saving directly (i.e. the conversion should be dropped)
+        model = DummyRoot()
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        model_state_dict = model.state_dict()
+        self.assertEqual(set(model_state_dict.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == model_state_dict[k]).all())
+
+    def test_can_add_prefix(self):
+        # we cannot have another param next to `self.model`, otherwise the prefix would also be added to it, even
+        # with correct checkpoints already starting with the prefix
+        model = DummyRoot(with_mlp=False)
+        model.config = PretrainedConfig()
+
+        bad_serialized_checkpoints = {k.removeprefix("model."): v.clone() for k, v in model.state_dict().items()}
+        weight_mapping = [PrefixChange(prefix_to_add="model")]
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            bad_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, dropping the prefix again
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(bad_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == bad_serialized_checkpoints[k]).all())
+
+        # Now, check that using the same conversion with already good keys works when loading and resaving
+        good_serialized_checkpoints = {k: v.clone() for k, v in model.state_dict().items()}
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            good_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, i.e. it will not remove the prefix since it was
+        # already present at loading time
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(good_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == good_serialized_checkpoints[k]).all())
+
+        # Now, use a fresh model, without going through loading first, so the model won't have `_weight_conversions` attached
+        # and the prefix should not be removed when saving directly (i.e. the conversion should be dropped)
+        model = DummyRoot()
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        model_state_dict = model.state_dict()
+        self.assertEqual(set(model_state_dict.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == model_state_dict[k]).all())
+
+    def test_can_remove_prefix_submodule(self):
+        model = DummyRoot()
+        model.config = PretrainedConfig()
+
+        bad_serialized_checkpoints = {
+            f"model.layers.bad_name.{k.replace('model.layers.', '')}" if "model.layers." in k else k: v.clone()
+            for k, v in model.state_dict().items()
+        }
+        weight_mapping = [PrefixChange(prefix_to_remove="bad_name", model_prefix="model.layers")]
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            bad_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, re-adding the bad prefix
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(bad_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == bad_serialized_checkpoints[k]).all())
+
+        # Now, check that using the same conversion with already good keys works when loading and resaving
+        good_serialized_checkpoints = {k: v.clone() for k, v in model.state_dict().items()}
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            good_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, i.e. it will not re-add the bad prefix since it was
+        # not present at loading time
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(good_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == good_serialized_checkpoints[k]).all())
+
+        # Now, use a fresh model, without going through loading first, so the model won't have `_weight_conversions` attached
+        # and the prefix should not be added when saving directly (i.e. the conversion should be dropped)
+        model = DummyRoot()
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        model_state_dict = model.state_dict()
+        self.assertEqual(set(model_state_dict.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == model_state_dict[k]).all())
+
+    def test_can_add_prefix_submodule(self):
+        # we cannot have another param next to `self.model`, otherwise the prefix would also be added to it, even
+        # with correct checkpoints already starting with the prefix
+        model = DummyRoot(with_mlp=False)
+        model.config = PretrainedConfig()
+
+        bad_serialized_checkpoints = {k.replace(".layers.", "."): v.clone() for k, v in model.state_dict().items()}
+        weight_mapping = [PrefixChange(prefix_to_add="layers", model_prefix="model")]
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            bad_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, dropping the prefix again
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(bad_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == bad_serialized_checkpoints[k]).all())
+
+        # Now, check that using the same conversion with already good keys works when loading and resaving
+        good_serialized_checkpoints = {k: v.clone() for k, v in model.state_dict().items()}
+
+        loading_info, _ = convert_and_load_state_dict_in_model(
+            model,
+            good_serialized_checkpoints,
+            LoadStateDictConfig(weight_mapping=copy.deepcopy(weight_mapping)),
+            tp_plan=None,
+        )
+
+        # Assert we can load without issues
+        self.assertEqual(loading_info.missing_keys, set())
+        self.assertEqual(loading_info.unexpected_keys, set())
+        self.assertEqual(loading_info.mismatched_keys, set())
+        self.assertEqual(loading_info.conversion_errors, {})
+
+        # Assert that re-saving will lead to the exact same state_dict, i.e. it will not remove the prefix since it was
+        # already present at loading time
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        self.assertEqual(set(good_serialized_checkpoints.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == good_serialized_checkpoints[k]).all())
+
+        # Now, use a fresh model, without going through loading first, so the model won't have `_weight_conversions` attached
+        # and the prefix should not be removed when saving directly (i.e. the conversion should be dropped)
+        model = DummyRoot()
+        saved_state_dict = revert_weight_conversion(model, model.state_dict())
+        model_state_dict = model.state_dict()
+        self.assertEqual(set(model_state_dict.keys()), set(saved_state_dict.keys()))
+        for k, v in saved_state_dict.items():
+            self.assertTrue((v == model_state_dict[k]).all())
+

 if __name__ == "__main__":
     unittest.main()
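For completeness, this is roughly how a downstream model would opt into the new transform, mirroring the `clip_vision_model`/`clip_text_model` entries added in `conversion_mapping.py`. The positional `(model_type, mapping)` signature of `register_checkpoint_conversion_mapping` is assumed from the hunk shown earlier, and `my_vision_model` is a hypothetical model type:

    # Hypothetical registration sketch; the exact keyword names of
    # register_checkpoint_conversion_mapping are assumed, not confirmed by this diff.
    from transformers.conversion_mapping import register_checkpoint_conversion_mapping
    from transformers.core_model_loading import PrefixChange

    register_checkpoint_conversion_mapping(
        "my_vision_model",  # hypothetical model_type
        [PrefixChange(prefix_to_remove="vision_model")],
    )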