diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 055688f27be8..442b19a7c9aa 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -410,15 +410,17 @@ def _attn_implementation(self):
     def _attn_implementation(self, value: Optional[Union[str, dict]]):
         """We set it recursively on the sub-configs as well"""
         # Set if for current config
-        attn_implementation = value if not isinstance(value, dict) else value.get("", self._attn_implementation)
+        current_attn = getattr(self, "_attn_implementation", None)
+        attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn)
         self._attn_implementation_internal = attn_implementation
         # Set it recursively on the subconfigs
         for subconfig_key in self.sub_configs:
             subconfig = getattr(self, subconfig_key, None)
             if subconfig is not None:
+                current_subconfig_attn = getattr(subconfig, "_attn_implementation", None)
                 sub_implementation = (
-                    value if not isinstance(value, dict) else value.get(subconfig_key, subconfig._attn_implementation)
+                    value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn)
                 )
                 subconfig._attn_implementation = sub_implementation
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 678f5173f669..fff4bb896312 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3680,6 +3680,20 @@ def test_attn_implementation_composite_models(self):
             model = model_class(config)
             self.assertTrue(model.config.get_text_config(decoder=True)._attn_implementation == "eager")
 
+        # Test that using `dict` attention implementation works with `from_pretrained`
+        # Set all backbones to "eager" because "eager" attention is always available
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            new_model = model.from_pretrained(tmpdirname, attn_implementation=attn_implementation_per_subconfig)
+            self.assertTrue(new_model.config._attn_implementation == "eager")
+            for submodule in new_model.modules():
+                if (
+                    submodule is not new_model
+                    and isinstance(submodule, PreTrainedModel)
+                    and submodule.config.__class__ != new_model.config.__class__
+                ):
+                    self.assertTrue(submodule.config._attn_implementation == "eager")
+
     @require_torch_sdpa
     def test_sdpa_can_dispatch_non_composite_models(self):
         """
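
For reviewers, a minimal usage sketch of the dict-valued `attn_implementation` that the new test exercises: the `""` key targets the top-level config and the remaining keys are matched against `config.sub_configs`, mirroring the setter change above. The checkpoint id and the sub-config key names below are illustrative placeholders, not taken from this patch.

```python
from transformers import AutoModel

# Hypothetical composite checkpoint; the sub-config keys ("text_config", "vision_config")
# are assumptions for illustration and must match the model's `config.sub_configs`.
model = AutoModel.from_pretrained(
    "org/composite-checkpoint",
    attn_implementation={
        "": "eager",             # "" sets the implementation on the top-level config
        "text_config": "eager",  # per-sub-config overrides, keyed by attribute name
        "vision_config": "eager",
    },
)
assert model.config._attn_implementation == "eager"
```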