From 75008c47b37458d5c0b5ee5c2a9806e43bf6784b Mon Sep 17 00:00:00 2001
From: Koichi Yasuoka
Date: Sat, 4 Apr 2026 15:53:09 +0900
Subject: [PATCH] deepcopy old_lm_head before changing input_embeddings

---
 src/transformers/modeling_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f72b230d9a20..a5daa9ee5e29 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2662,6 +2662,7 @@ def resize_token_embeddings(
 
     def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
         old_embeddings = self.get_input_embeddings()
+        old_lm_head = copy.deepcopy(self.get_output_embeddings())
         new_embeddings = self._get_resized_embeddings(
             old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing
         )
@@ -2684,8 +2685,7 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean
             new_num_tokens = new_embeddings.weight.shape[0]
 
         # if word embeddings are not tied, make sure that lm head is resized as well
-        if self.get_output_embeddings() is not None:
-            old_lm_head = self.get_output_embeddings()
+        if old_lm_head is not None:
            if isinstance(old_lm_head, torch.nn.Embedding):
                 new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
             else:
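
Context note (not part of the patch): a minimal sketch of why the lm head has to be
snapshotted before the input embeddings are touched. It assumes a checkpoint whose word
embeddings are tied; "gpt2" and the resize amount are arbitrary choices for illustration.

    import copy
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # assumed example; any tied-embedding model works

    # With tied weights the lm head and the input embeddings share one tensor,
    # so anything that mutates the input embeddings also reaches the "old" lm head.
    lm_head = model.get_output_embeddings()
    print(lm_head.weight.data_ptr() == model.get_input_embeddings().weight.data_ptr())  # True

    # A deep copy taken up front keeps the original (vocab_size, hidden_size) weights
    # even after the embeddings are resized and the tie is re-established.
    snapshot = copy.deepcopy(lm_head)
    model.resize_token_embeddings(model.config.vocab_size + 16)
    print(snapshot.weight.shape, model.get_output_embeddings().weight.shape)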