From 2650a10944cbf2423b653054bee766b0fb17e698 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 7 Aug 2025 12:10:44 +0200 Subject: [PATCH 1/3] fix --- src/transformers/configuration_utils.py | 6 ++---- tests/test_modeling_common.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 70f9979a8f14..8bec1e3c1e20 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -410,16 +410,14 @@ def _attn_implementation(self): def _attn_implementation(self, value: Optional[Union[str, dict]]): """We set it recursively on the sub-configs as well""" # Set if for current config - attn_implementation = value if not isinstance(value, dict) else value.get("", self._attn_implementation) + attn_implementation = value if not isinstance(value, dict) else value.get("", None) self._attn_implementation_internal = attn_implementation # Set it recursively on the subconfigs for subconfig_key in self.sub_configs: subconfig = getattr(self, subconfig_key, None) if subconfig is not None: - sub_implementation = ( - value if not isinstance(value, dict) else value.get(subconfig_key, subconfig._attn_implementation) - ) + sub_implementation = value if not isinstance(value, dict) else value.get(subconfig_key, None) subconfig._attn_implementation = sub_implementation def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8081317505e8..3d87255437cf 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3663,6 +3663,20 @@ def test_attn_implementation_composite_models(self): model = model_class(config) self.assertTrue(model.config.get_text_config(decoder=True)._attn_implementation == "eager") + # Test that using `dict` attention implementation works with `from_pretrained` + # Set all 
backbones to "eager" because "eager" attention is always available + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = model.from_pretrained(tmpdirname, attn_implementation=attn_implementation_per_subconfig) + self.assertTrue(new_model.config._attn_implementation == "eager") + for submodule in new_model.modules(): + if ( + submodule is not new_model + and isinstance(submodule, PreTrainedModel) + and submodule.config.__class__ != new_model.config.__class__ + ): + self.assertTrue(submodule.config._attn_implementation == "eager") + @require_torch_sdpa def test_sdpa_can_dispatch_non_composite_models(self): """ From 86d56bb52e4533f8c719ce20ebd605ecc50e5e82 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 14 Aug 2025 11:01:12 +0200 Subject: [PATCH 2/3] use non-explicit `None` --- src/transformers/configuration_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 8bec1e3c1e20..9210440558b1 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -410,14 +410,14 @@ def _attn_implementation(self): def _attn_implementation(self, value: Optional[Union[str, dict]]): """We set it recursively on the sub-configs as well""" # Set if for current config - attn_implementation = value if not isinstance(value, dict) else value.get("", None) + attn_implementation = value if not isinstance(value, dict) else value.get("") self._attn_implementation_internal = attn_implementation # Set it recursively on the subconfigs for subconfig_key in self.sub_configs: subconfig = getattr(self, subconfig_key, None) if subconfig is not None: - sub_implementation = value if not isinstance(value, dict) else value.get(subconfig_key, None) + sub_implementation = value if not isinstance(value, dict) else value.get(subconfig_key) subconfig._attn_implementation = sub_implementation def 
save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): From fd296439db495ce36b61b548d1a923e7e5e9545f Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 19 Aug 2025 11:22:48 +0200 Subject: [PATCH 3/3] keep previously set attn if exists --- src/transformers/configuration_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 9210440558b1..f07ed8565bd5 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -410,14 +410,18 @@ def _attn_implementation(self): def _attn_implementation(self, value: Optional[Union[str, dict]]): """We set it recursively on the sub-configs as well""" # Set if for current config - attn_implementation = value if not isinstance(value, dict) else value.get("") + current_attn = getattr(self, "_attn_implementation", None) + attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn) self._attn_implementation_internal = attn_implementation # Set it recursively on the subconfigs for subconfig_key in self.sub_configs: subconfig = getattr(self, subconfig_key, None) if subconfig is not None: - sub_implementation = value if not isinstance(value, dict) else value.get(subconfig_key) + current_subconfig_attn = getattr(subconfig, "_attn_implementation", None) + sub_implementation = ( + value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn) + ) subconfig._attn_implementation = sub_implementation def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):