Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,15 +410,17 @@ def _attn_implementation(self):
def _attn_implementation(self, value: Optional[Union[str, dict]]):
"""We set it recursively on the sub-configs as well"""
# Set it for the current config
attn_implementation = value if not isinstance(value, dict) else value.get("", self._attn_implementation)
current_attn = getattr(self, "_attn_implementation", None)
attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn)
self._attn_implementation_internal = attn_implementation

# Set it recursively on the subconfigs
for subconfig_key in self.sub_configs:
subconfig = getattr(self, subconfig_key, None)
if subconfig is not None:
current_subconfig_attn = getattr(subconfig, "_attn_implementation", None)
sub_implementation = (
value if not isinstance(value, dict) else value.get(subconfig_key, subconfig._attn_implementation)
value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn)
)
subconfig._attn_implementation = sub_implementation

Expand Down
14 changes: 14 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3680,6 +3680,20 @@ def test_attn_implementation_composite_models(self):
model = model_class(config)
self.assertTrue(model.config.get_text_config(decoder=True)._attn_implementation == "eager")

# Test that using a `dict` attention implementation works with `from_pretrained`
# Set all backbones to "eager" because "eager" attention is always available
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = model.from_pretrained(tmpdirname, attn_implementation=attn_implementation_per_subconfig)
self.assertTrue(new_model.config._attn_implementation == "eager")
for submodule in new_model.modules():
if (
submodule is not new_model
and isinstance(submodule, PreTrainedModel)
and submodule.config.__class__ != new_model.config.__class__
):
self.assertTrue(submodule.config._attn_implementation == "eager")

@require_torch_sdpa
def test_sdpa_can_dispatch_non_composite_models(self):
"""
Expand Down