From 0c0efedee5f6d74dfe52c0ece6682458a1df6073 Mon Sep 17 00:00:00 2001
From: Tom Reichel
Date: Tue, 23 Jan 2024 06:49:21 +0000
Subject: [PATCH 1/2] test that tied output embeddings aren't initialized on
 load

---
 tests/test_modeling_common.py | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index b5189124a78b..fc904d0bc445 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -483,6 +483,40 @@ def _init_weights(self, module):
                 max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
                 self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
 
+    def test_fast_init_tied_embeddings(self):
+        class MyClass(PreTrainedModel):
+            config_class = PretrainedConfig
+            _tied_weights_keys = ["output_embeddings.weight"]
+
+            def __init__(self, config=None):
+                super().__init__(config if config is not None else PretrainedConfig())
+                self.input_embeddings = nn.Embedding(10, 10)
+                self.output_embeddings = nn.Linear(10, 10, bias=False)
+                self.tie_weights()
+
+            def get_output_embeddings(self):
+                return self.output_embeddings
+
+            def set_output_embeddings(self, output_embeddings):
+                self.output_embeddings = output_embeddings
+
+            def get_input_embeddings(self):
+                return self.input_embeddings
+
+            def set_input_embeddings(self, input_embeddings):
+                self.input_embeddings = input_embeddings
+
+            def _init_weights(self, module):
+                if module is self.output_embeddings:
+                    raise ValueError("unnecessarily initialized tied output embedding!")
+
+        model = MyClass()
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            # raises if from_pretrained initializes the tied output_embeddings
+            MyClass.from_pretrained(tmpdirname)
+
     def test_save_load_fast_init_to_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if config.__class__ not in MODEL_MAPPING:

From 28fe452bfe2e9c731783a305497199729303dd23 Mon Sep 17 00:00:00 2001
From: Tom Reichel
Date: Fri, 22 Dec 2023 03:16:52 +0000
Subject: [PATCH 2/2] don't initialize the output embeddings if we're going
 to tie them to the input embeddings

---
 src/transformers/modeling_utils.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 15855ceb7ee1..8a4fd6eaee4c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3746,6 +3746,11 @@ def _fix_key(key):
                 else:
                     _loaded_keys = loaded_keys
                 not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
+                # If we're about to tie the output embeddings to the input embeddings, we don't need to init them
+                if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
+                    output_embeddings = model.get_output_embeddings()
+                    if output_embeddings is not None:
+                        output_embeddings._is_hf_initialized = True
             else:
                 not_initialized_submodules = dict(model.named_modules())
             # This will only initialize submodules that are not marked as initialized by the line above.
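
For context, a minimal sketch (not part of the patch series) of the invariant patch 2 relies on: when config.tie_word_embeddings is True, the output embedding shares storage with the input embedding after from_pretrained, so a separate _init_weights pass over it is redundant at best. The checkpoint "gpt2" is only an illustrative assumption; any checkpoint that ties its embeddings would do.

    # Sketch: confirm tied embeddings alias the same storage after loading,
    # so skipping fast init for the output embedding loses nothing.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # assumption: a tied-embedding checkpoint
    assert model.config.tie_word_embeddings

    inp = model.get_input_embeddings().weight
    out = model.get_output_embeddings().weight
    # Tied weights point at one underlying tensor, so re-initializing the
    # output embedding would also clobber the loaded input embedding.
    assert inp.data_ptr() == out.data_ptr()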