[loading] Clean way to add/remove full parts in checkpoint names #45448
@@ -25,7 +25,6 @@
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import chain
from typing import TYPE_CHECKING, Any

@@ -579,41 +578,40 @@ def process_source_pattern(source_pattern: str, target_pattern: str) -> str:
    return source_pattern

@dataclass(slots=True)

Comment on lines 580 to -582 (Member, Author): They were
class WeightTransform:
    source_patterns: str | list[str] = field(init=True)
    target_patterns: str | list[str] = field(init=True)
    compiled_sources: re.Pattern = field(init=False)

    distributed_operation: TensorParallelLayer | None = None
    quantization_operation: ConversionOps | None = None

    collected_tensors: dict[str, list[Future]] = field(default_factory=lambda: defaultdict(list), init=False)
    layer_targets: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set), init=False)
    # Restrict the attributes that can be attached
    __slots__ = (
        "source_patterns",
        "target_patterns",
        "compiled_sources",
        "distributed_operation",
        "quantization_operation",
        "collected_tensors",
        "layer_targets",
        "_original_source_patterns",
        "_original_target_patterns",
        "_was_used",
    )

    # Those are needed to be able to reverse correctly the transform, as the patterns may be processed
    _original_source_patterns: list[str] = field(init=False)
    _original_target_patterns: list[str] = field(init=False)
    def __init__(self, source_patterns: str | list[str], target_patterns: str | list[str]):
        self.source_patterns: list[str] = source_patterns
        self.target_patterns: list[str] = target_patterns
        # Those are needed to be able to reverse correctly the transform, as the patterns may be processed
        self._original_source_patterns = self.source_patterns.copy()
        self._original_target_patterns = self.target_patterns.copy()

    def __setattr__(self, name, value):
        if name in ("source_patterns", "target_patterns"):
            # We do not allow to re-set the patterns, as they are linked between each other and changing one
            # without the other can mess-up with the capturing groups/compiled sources
            if hasattr(self, name):
                raise ValueError(f"Cannot assign to field {name}, you should create a new instance")
            # Switch str to list
            elif isinstance(value, str):
                value = [value]
        object.__setattr__(self, name, value)
        # Init fields that will be used during conversion
        self.distributed_operation: TensorParallelLayer | None = None
        self.quantization_operation: ConversionOps | None = None
        self.collected_tensors: dict[str, list[Future]] = defaultdict(list)
        self.layer_targets: dict[str, set[str]] = defaultdict(set)

    def __post_init__(self):
        # Due to how our `_checkpoint_conversion_mapping` mappings are written, we need a few exceptions here
        # when instantiating the reverse mapping (i.e. the targets become sources, and sources become targets)
        # The issues lie in the sources usually, so here we need to check the targets for the reversed mapping
        # Flag to notice if the Transform was used
        self._was_used = False

        # We need to copy the exact original patterns to later reverse (before processing may change them)
        self._original_source_patterns = self.source_patterns.copy()
        self._original_target_patterns = self.target_patterns.copy()
        # We need to process a few exceptions here when instantiating the reverse mapping (i.e. the targets become
        # sources, and sources become targets). The issues lie in the sources usually, so here we need to check the
        # targets for the reversed mapping

        # Process target_patterns: detect capturing groups and replace with \1
        # Store the original capturing group patterns for reverse mapping

@@ -657,6 +655,20 @@ def __post_init__(self):
            branches.append(f"(?P<{group_name}>{pattern})")
        self.compiled_sources = re.compile("|".join(branches))
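
As a side note, here is a minimal standalone sketch (plain `re`, with made-up patterns, not the library's own code) of what the compiled alternation ends up looking like: every source pattern becomes a named branch `_0`, `_1`, ..., so a single `search` call can test all of them at once.

```python
import re

# Hypothetical source patterns; the `_0`, `_1` names mirror the enumeration scheme above
source_patterns = [r"gamma", r"beta"]
branches = [f"(?P<_{i}>{pattern})" for i, pattern in enumerate(source_patterns)]
compiled_sources = re.compile("|".join(branches))

print(compiled_sources.pattern)                        # (?P<_0>gamma)|(?P<_1>beta)
print(bool(compiled_sources.search("layer.0.beta")))   # True
```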

    def __repr__(self):
        return f"{self.__class__.__name__}(source_patterns={self.source_patterns}, target_patterns={self.target_patterns})"

    def __setattr__(self, name, value):
        if name in ("source_patterns", "target_patterns"):
            # We do not allow to re-set the patterns, as they are linked between each other and changing one
            # without the other can mess-up with the capturing groups/compiled sources
            if hasattr(self, name):
                raise ValueError(f"Cannot assign to field {name}, you should create a new instance")
            # Switch str to list
            elif isinstance(value, str):
                value = [value]
        object.__setattr__(self, name, value)

    def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future):
        self.collected_tensors[source_pattern].append(future)
        self.layer_targets[target_key].add(source_key)

@@ -673,6 +685,9 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
        if match_object is None:
            return source_key, None

        # We have a match, so the Transform was used
        self._was_used = True

        # Find the source that produced the match (it's the first group that matched, as the search stops after first branch match)
        matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None)
        source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
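
A small hedged illustration of the branch-resolution trick used just above (standalone, hypothetical patterns): the only non-`None` entry of `groupdict()` names the branch that matched, and the integer after the underscore indexes back into `source_patterns`.

```python
import re

source_patterns = [r"gamma", r"beta"]  # hypothetical patterns, as in the sketch above
compiled = re.compile("|".join(f"(?P<_{i}>{p})" for i, p in enumerate(source_patterns)))

match_object = compiled.search("encoder.layer.0.beta")
matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None)
print(matching_group_name)                             # _1
print(source_patterns[int(matching_group_name[1:])])   # beta
```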

@@ -731,11 +746,23 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
        return collected_tensors

    def was_used(self) -> bool:
        """
        Return whether the current Transform matched any weights during loading/saving. This is needed as some
        weight renaming transforms are not bijective, i.e. if we drop/add full parts of a name with PrefixChange, we
        lose some information that we cannot get back if we don't know whether the Transform was already used (say we
        have a prefix to drop: during saving, we need to know whether the checkpoints we loaded contained said prefix
        before deciding to add it back or not).
        """
        return self._was_used
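
To make the non-bijectivity argument concrete, here is a small self-contained sketch (hypothetical prefix and keys, not the library API): once a prefix has been stripped, the loaded keys look identical whether or not the checkpoint originally carried it, so only a flag recorded at load time can tell whether to add it back when saving.

```python
import re

def strip_prefix(keys, prefix):
    """Drop `prefix.` from each key and remember whether it was ever present."""
    was_used = False
    stripped = []
    for key in keys:
        new_key, n_subs = re.subn(rf"^{prefix}\.", "", key)
        was_used |= bool(n_subs)
        stripped.append(new_key)
    return stripped, was_used

# A checkpoint with the prefix and one without end up with the same keys after loading
print(strip_prefix(["bert.embeddings.weight"], "bert"))  # (['embeddings.weight'], True)
print(strip_prefix(["embeddings.weight"], "bert"))       # (['embeddings.weight'], False)
```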

@dataclass(slots=True)
class WeightRenaming(WeightTransform):
    # Special case of WeightTransform that only renames keys without any conversion.

    # Needs to be empty, otherwise the class will not be slotted
    __slots__ = ()

    def convert(
        self,
        layer_name: str,

@@ -770,19 +797,72 @@ def convert(
        return collected_tensors

class PrefixChange(WeightRenaming):
    """
    Special case of WeightRenaming, used to simplify adding/removing full parts of a weight name. The regexes
    that are needed for such operations are complex, so this is a much easier API for such cases.
    """

    __slots__ = (
        "prefix_to_add",
        "prefix_to_remove",
        "model_prefix",
    )

    def __init__(
        self, prefix_to_add: str | None = None, prefix_to_remove: str | None = None, model_prefix: str | None = None
    ):
        if (prefix_to_add is None) ^ (prefix_to_remove is not None):
            raise ValueError("You must provide only one of `prefix_to_add` and `prefix_to_remove`")

        self.prefix_to_add = prefix_to_add
        self.prefix_to_remove = prefix_to_remove
        self.model_prefix = "" if model_prefix is None else model_prefix
        prefix = rf"{self.model_prefix}\." if self.model_prefix != "" else ""

        if prefix_to_add is not None:
            super().__init__(
                # We use a negative lookahead to avoid adding the prefix if we detect that it's already present
                source_patterns=rf"^{prefix}(?:(?!{prefix_to_add}\.))(.+)$",
                target_patterns=rf"{prefix}{prefix_to_add}\.\1",
            )
        else:
            super().__init__(source_patterns=rf"^{prefix}{prefix_to_remove}\.(.+)$", target_patterns=rf"{prefix}\1")
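
For intuition, the two branches above expand to regexes roughly like the ones in this standalone sketch (hypothetical prefix `model`, plain `re.sub` instead of the `WeightRenaming` machinery):

```python
import re

add_src = r"^(?:(?!model\.))(.+)$"   # only matches keys that do NOT already start with "model."
remove_src = r"^model\.(.+)$"        # only matches keys that DO start with "model."

print(re.sub(add_src, r"model.\1", "lm_head.weight"))        # model.lm_head.weight
print(re.sub(add_src, r"model.\1", "model.lm_head.weight"))  # unchanged: the lookahead prevents a double prefix
print(re.sub(remove_src, r"\1", "model.lm_head.weight"))     # lm_head.weight
```

Swapping `prefix_to_add` and `prefix_to_remove` switches between the two branches, which is exactly what `reverse_transform` below relies on.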

    def reverse_transform(self) -> WeightTransform:
        """Reverse the current `WeightTransform` instance, to be able to save with the opposite weight transformations."""
        # TODO: check this and relax when quantizers have `reverse_op`
        if self.quantization_operation is not None:
            raise ValueError("Cannot reverse the transform with TP or quantization")

        # Only one of the 2 can ever be used, so 1 is always None
        return PrefixChange(
            prefix_to_add=self.prefix_to_remove, prefix_to_remove=self.prefix_to_add, model_prefix=self.model_prefix
        )

    def with_submodel_prefix(self, prefix: str) -> PrefixChange:
        new_prefix = f"{prefix}.{self.model_prefix}" if self.model_prefix != "" else prefix
        return PrefixChange(
            prefix_to_add=self.prefix_to_add, prefix_to_remove=self.prefix_to_remove, model_prefix=new_prefix
        )

# List of classes that are known to be able to use m:n
_INTERNAL_MANY_TO_MANY_CONVERSIONS = (
    ErnieFuseAndSplitTextVisionExperts,
    ErnieSplitAndDecoupleTextVisionExperts,
)


@dataclass(slots=True)
class WeightConverter(WeightTransform):
    operations: list[ConversionOps] = field(default_factory=list, repr=False)
    __slots__ = ("operations",)

    def __init__(
        self, source_patterns: str | list[str], target_patterns: str | list[str], operations: list[ConversionOps]
    ):
        super().__init__(source_patterns, target_patterns)
        self.operations: list[ConversionOps] = operations

    def __post_init__(self):
        WeightTransform.__post_init__(self)
        if bool(len(self.source_patterns) - 1) + bool(len(self.target_patterns) - 1) >= 2:
            # We allow many-to-many only if we use an internal operation that can handle it
            if not any(isinstance(op, _INTERNAL_MANY_TO_MANY_CONVERSIONS) for op in self.operations):
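
The arithmetic in the `__post_init__` check reads a bit cryptically; it simply means "both sides have more than one pattern". A quick hedged check (hypothetical pattern lists):

```python
def is_many_to_many(source_patterns, target_patterns):
    # bool(len(x) - 1) is 0 for a single pattern and 1 for several, so the sum reaches 2
    # only when both sides hold more than one pattern
    return bool(len(source_patterns) - 1) + bool(len(target_patterns) - 1) >= 2

print(is_many_to_many(["a"], ["b", "c"]))       # False: one-to-many is allowed
print(is_many_to_many(["a", "b"], ["c"]))       # False: many-to-one is allowed
print(is_many_to_many(["a", "b"], ["c", "d"]))  # True: many-to-many needs a dedicated operation
```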

@@ -1342,8 +1422,11 @@ def convert_and_load_state_dict_in_model(
    # `cancel_futures=True` in case the program was interrupted, to avoid wasting time on exit
    thread_pool.shutdown(wait=False, cancel_futures=True)

    # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
    model._weight_conversions = weight_mapping
    # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user), but
    # only if it was used, i.e. it matched any weight from the checkpoints
    model_specific_conversions = [conversion for conversion in weight_mapping if conversion.was_used()]

Comment on lines +1425 to +1427

Member: Does it not create another loophole if some conversions aren't used by the model? But probably, if the test checks the full list of conversions, we can consciously decide to add task-specific patterns like

Contributor: Not sure I can follow, do you have a small example where you think we would fall through?

Member (Author): The test that checks that all conversions are used (to make sure we don't add useless conversions) does it by checking the keys between the original model and the saved weights (because from our tiny models, we only know the "correct" keys, and the saved keys mimic the "wrong" keys), so it will still check it correctly!

    model._weight_conversions = model_specific_conversions

    return loading_info, disk_offload_index

@@ -1360,12 +1443,20 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
    # Do not resave with the legacy renaming, if present
    weight_conversions = get_model_conversion_mapping(model, add_legacy=False)
    # If the model had no `_weight_conversions` attached, drop any PrefixChange transform - this is because the
    # model was almost surely instantiated from scratch (at least not from `from_pretrained`), and PrefixChange with
    # `prefix_to_remove` would otherwise add an unwanted prefix (as we don't have any information about whether the prefix
    # was there or not during load)
    weight_conversions = [x for x in weight_conversions if not isinstance(x, PrefixChange)]
    weight_conversions = weight_conversions if len(weight_conversions) > 0 else None

    # We did not find any operations to perform -> quick escape
    if weight_conversions is None:
        return state_dict

    # Important: we need to revert the order here, so that potential conversions from submodels are performed first
    weight_conversions = weight_conversions[::-1]

    # Reverse all Transform to correctly match keys
    reverse_weight_conversion = [conversion.reverse_transform() for conversion in weight_conversions]
    # If we are still here, we need to create the (reverse) conversion mapping from scratch
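
Putting the save-time pieces together, a rough standalone sketch of the reversal flow (hypothetical (pattern, replacement) pairs standing in for reversed transforms; the real code goes through `reverse_transform` and the conversion operations):

```python
import re

# Reversed transforms, applied in reverse order relative to loading
reverse_conversions = [
    (r"attention\.output", "attn.out_proj"),   # undo a load-time renaming
    (r"^(?:(?!model\.))(.+)$", r"model.\1"),   # re-add the prefix that was dropped at load time
]

state_dict_keys = ["encoder.0.attention.output.weight"]
for pattern, replacement in reverse_conversions:
    state_dict_keys = [re.sub(pattern, replacement, key) for key in state_dict_keys]
print(state_dict_keys)  # ['model.encoder.0.attn.out_proj.weight']
```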