Support any capturing groups in WeightTransform reverse mapping #43205
Cyrilvallez
left a comment
Very nice, happy to move this bit of logic to a proper function as it becomes a bit complex!
We need to be very careful about the logic though, see also my comment, i.e.
In all generality, purely index-based matching will break with your current implementation for a Transform with a different number of sources and targets. For example, Operation(sources=r"blabla.(\d+)", targets=["bla1.\1", "bla2.\1"]) will break during the reverse op, even though it is correctly defined in terms of pattern replacement, since there is only 1 captured pattern.
If we have several sources with different capturing groups, this is much harder to match back correctly (I believe it's actually undetermined in general, as we cannot know what to match to what).
So TL;DR: let's make sure we only have 1 pattern inside target_capturing_groups at the end, raise if we have several for now, and always replace with this unique pattern on ALL the sources.
logger = logging.get_logger(__name__)


def reverse_target_pattern(pattern: str) -> tuple[str, str | None]:
Name is a bit misleading IMO!
- def reverse_target_pattern(pattern: str) -> tuple[str, str | None]:
+ def process_target_pattern(pattern: str) -> tuple[str, str | None]:
# Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
capturing_group_match = re.search(r"\([^)]+\)", pattern)
A slightly simpler version such as this one should work the same way no?
- # Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
- capturing_group_match = re.search(r"\([^)]+\)", pattern)
+ # Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
+ capturing_group_match = re.search(r"\(.+?\)", pattern)
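To check the "should work the same way" question on a few inputs, here is a small sketch that runs both regexes side by side. The sample patterns and the helper `first_group` are hypothetical and chosen for illustration; they are not taken from the actual conversion mappings.

```python
import re
from typing import Optional

# The two candidate regexes from the suggestion above.
NEGATED_CLASS = r"\([^)]+\)"
LAZY_DOT = r"\(.+?\)"

# Hypothetical sample patterns (assumption: typical single-line mapping patterns).
samples = [r"blabla.(\d+)", r"(.+)", r"model.layers.(\d+).weight", "no_group_here"]


def first_group(regex: str, pattern: str) -> Optional[str]:
    """Return the first parenthesized group found in `pattern`, or None."""
    match = re.search(regex, pattern)
    return match.group(0) if match else None


# Map each sample to the pair (negated-class result, lazy-dot result).
results = {p: (first_group(NEGATED_CLASS, p), first_group(LAZY_DOT, p)) for p in samples}

for pattern, (negated, lazy) in results.items():
    print(f"{pattern!r}: {negated!r} vs {lazy!r}")
```

On these samples both regexes pick out the same group. They can still diverge on unusual inputs, for instance when a newline sits inside the parentheses, since `.` does not match newlines by default.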
- pattern = pattern.replace(r"\1", r"(.+)")
+ # Use the stored capturing group from target_patterns
+ pattern = pattern.replace(r"\1", target_capturing_groups[capturing_groups_index], 1)
+ capturing_groups_index += 1
Here, in all generality, purely index-based matching will break for a Transform with a different number of sources and targets. For example, Operation(sources=r"blabla.(\d+)", targets=["bla1.\1", "bla2.\1"]) will break during the reverse op, even though it is correctly defined in terms of pattern replacement, since there is only 1 captured pattern.
If we have several sources with different capturing groups, this is much harder to match back correctly (I believe it's actually undetermined in general, as we cannot know what to match to what).
So TL;DR: let's make sure we only have 1 pattern inside target_capturing_groups at the end, raise if we have several for now, and always replace with this unique pattern on ALL the sources.
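The TL;DR above could look roughly like the following. This is a minimal sketch, not the code in this PR: the helpers `unique_capturing_group` and `reverse_patterns` are hypothetical names, and it assumes each pattern holds at most one capturing group, referenced as `\1` in the targets.

```python
import re
from typing import List, Optional, Tuple


def unique_capturing_group(source_patterns: List[str]) -> Optional[str]:
    """Return the single capturing group used across the sources, or None if there is none."""
    groups = set()
    for pattern in source_patterns:
        match = re.search(r"\(.+?\)", pattern)
        if match:
            groups.add(match.group(0))
    if len(groups) > 1:
        # Several different groups: we cannot know what to match to what, so refuse.
        raise ValueError(f"Cannot reverse, several capturing groups found: {sorted(groups)}")
    return groups.pop() if groups else None


def reverse_patterns(sources: List[str], targets: List[str]) -> Tuple[List[str], List[str]]:
    """Swap sources and targets, moving the unique group onto ALL the new sources."""
    group = unique_capturing_group(sources)
    if group is None:
        return list(targets), list(sources)
    # The backreference in each old target becomes the stored group (no index bookkeeping).
    new_sources = [t.replace(r"\1", group) for t in targets]
    # The group in each old source becomes the backreference.
    new_targets = [s.replace(group, r"\1") for s in sources]
    return new_sources, new_targets


# The 1-source / 2-targets case from the comment above.
new_sources, new_targets = reverse_patterns([r"blabla.(\d+)"], [r"bla1.\1", r"bla2.\1"])
print(new_sources, new_targets)
print(re.sub(new_sources[0], new_targets[0], "bla1.3"))
```

Because the replacement uses the one stored group rather than a running index, the number of sources and targets no longer has to match.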
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43205&sha=43b550
…ingface#43205)
* support any capturing groups in reverse mapping
* define utils and fix test_reverse_loading_mapping
* Fix test_reverse_loading_mapping
* fix non deterministic behavior.
* nit
Fix reverse mapping for capturing groups in WeightTransform
Previously assumed capturing groups were always `(.+)` when restoring `\1` in reverse mapping, breaking patterns like `(\d+)` used by DETR in this PR.
Now stores actual capturing group patterns and correctly restores them during reverse transforms.
Also adds a fix to `test_reverse_loading_mapping` in `test_modeling_common.py`, to reverse the target pattern before using it in a regex search, also needed for #41549.
Cc @Cyrilvallez
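As an illustration of one way the `(.+)` assumption can go wrong, the sketch below restores `\1` both ways and compares what the reversed pattern matches. The rule `layers.(\d+)` -> `blocks.\1` and the key names are hypothetical, not taken from DETR or from the library's mappings.

```python
import re

# Hypothetical forward rule: rename "layers.<index>" to "blocks.<index>".
source_pattern = r"layers.(\d+)"
target_pattern = r"blocks.\1"

# Store the actual capturing group from the source instead of assuming "(.+)".
stored_group = re.search(r"\(.+?\)", source_pattern).group(0)

# Reverse direction: the old target becomes the pattern to search for.
reversed_with_assumed = target_pattern.replace(r"\1", r"(.+)")
reversed_with_stored = target_pattern.replace(r"\1", stored_group)

# "blocks.norm" is not a numbered block, so the forward rule never produced it.
key = "blocks.norm"
assumed_hit = re.search(reversed_with_assumed, key) is not None
stored_hit = re.search(reversed_with_stored, key) is not None

print(f"assumed (.+): {assumed_hit}, stored {stored_group}: {stored_hit}")
```

With the assumed group the reversed pattern also claims keys the forward rule never touched, while the stored group keeps the reverse mapping as narrow as the original.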