
Prevent Reinitialization of Resized LM Head When tie_word_embeddings is False #35141#36221

Open
sambhavnoobcoder wants to merge 12 commits into huggingface:main from sambhavnoobcoder:output-embedding-reinitilaization

Conversation

@sambhavnoobcoder
Contributor

Issue Description

When a model is configured with tie_word_embeddings=False, calling resize_token_embeddings() followed by post_init() unintentionally reinitializes the output embeddings (the LM head). This happens because the LM head created during resizing lacks the _is_hf_initialized flag, so post_init() treats it as uninitialized and overwrites its weights.
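
A minimal reproduction sketch (the checkpoint name and the +10 vocabulary increase are illustrative; post_init() is called explicitly here to trigger the re-initialization path):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative: any causal LM config with untied embeddings shows the issue.
config = AutoConfig.from_pretrained("gpt2", tie_word_embeddings=False)
model = AutoModelForCausalLM.from_config(config)

# Snapshot the LM head before resizing.
before = model.get_output_embeddings().weight.detach().clone()

model.resize_token_embeddings(config.vocab_size + 10)
model.post_init()  # re-runs init on modules not marked _is_hf_initialized

after = model.get_output_embeddings().weight[: before.shape[0]]
print(torch.allclose(before, after))  # False before this fix, True with it
```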

Solution

Added _is_hf_initialized = True flag to the new LM head in _get_resized_lm_head(). This ensures that post_init() recognizes the module as already initialized and skips reinitialization, preserving the intended weights.

The change is minimal and targeted:

  • Only affects cases where tie_word_embeddings=False
  • Maintains backward compatibility
  • Preserves existing initialization behavior for new tokens
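
A sketch of the one-line change inside `_get_resized_lm_head()` (surrounding code abridged; variable names approximate the real implementation rather than reproducing the exact diff):

```python
# Inside PreTrainedModel._get_resized_lm_head (abridged):
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)

# New: mark the freshly built head as initialized so a later post_init()
# skips it and keeps the weights copied over from the old head below.
new_lm_head._is_hf_initialized = True

# ... existing logic copies the old rows (and bias) into new_lm_head ...
```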

Test Coverage

Core Test: test_model_resize_embeddings.py

test_resize_embeddings_no_reinit

This test verifies:

  1. Takes initial snapshot of LM head weights
  2. Resizes embeddings (+10 tokens)
  3. Verifies original weights preserved after resize
  4. Calls post_init()
  5. Verifies original weights still match initial snapshot
  • Uses torch.allclose() for element-wise weight comparison
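
A condensed sketch of the check (not the literal test; `model` stands in for a model instance built with tie_word_embeddings=False):

```python
import torch

def check_no_reinit(model, extra_tokens=10):
    # 1. Snapshot the LM head before resizing.
    old = model.get_output_embeddings().weight.detach().clone()

    # 2. Resize the vocabulary.
    model.resize_token_embeddings(model.config.vocab_size + extra_tokens)

    # 3. The original rows must survive the resize.
    resized = model.get_output_embeddings().weight
    assert torch.allclose(old, resized[: old.shape[0]])

    # 4./5. They must also survive a subsequent post_init().
    model.post_init()
    after = model.get_output_embeddings().weight
    assert torch.allclose(old, after[: old.shape[0]])
```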

test_new_tokens_initialization

This test verifies:

  1. Resizes vocabulary (+10 tokens)
  2. Examines only the new token weights
  3. Verifies that new weights:
    • Are not all zeros
    • Stay within reasonable bounds (abs value < 100)
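
A sketch of that check (again with `model` standing in for an untied test model):

```python
import torch

def check_new_token_init(model, extra_tokens=10):
    old_vocab = model.config.vocab_size
    model.resize_token_embeddings(old_vocab + extra_tokens)

    # Only the rows appended for the new tokens are inspected.
    new_rows = model.get_output_embeddings().weight[old_vocab:]
    assert not torch.all(new_rows == 0)      # not left as zeros
    assert torch.all(new_rows.abs() < 100)   # stays within reasonable bounds
```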

test_resize_embeddings_with_bias

This test verifies:

  1. Creates model with tie_word_embeddings=False
  2. Adds bias to LM head (which is not present by default)
  3. Takes snapshots of both weights and bias
  4. Resizes embeddings (+10 tokens)
  5. Calls post_init()
  6. Verifies:
    • Original weights preserved in resized LM head
    • Original bias values preserved in resized LM head
    • Uses torch.allclose() for element-wise comparison of both weights and bias
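
A sketch of the bias variant (the bias is attached manually because the default head has none; helper and variable names are illustrative):

```python
import torch
from torch import nn

def check_resize_with_bias(model, extra_tokens=10):
    # Replace the default bias-free head with one that carries a bias.
    head = model.get_output_embeddings()
    biased = nn.Linear(head.in_features, head.out_features, bias=True)
    biased.weight.data.copy_(head.weight.data)
    model.set_output_embeddings(biased)

    old_w = biased.weight.detach().clone()
    old_b = biased.bias.detach().clone()

    model.resize_token_embeddings(model.config.vocab_size + extra_tokens)
    model.post_init()

    resized = model.get_output_embeddings()
    assert torch.allclose(old_w, resized.weight[: old_w.shape[0]])
    assert torch.allclose(old_b, resized.bias[: old_b.shape[0]])
```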

Test Results

All tests pass successfully, confirming:

  • Original token embeddings remain unchanged through resize and post_init
  • New token embeddings are properly initialized
  • No regressions in existing functionality
[Screenshot: passing test run, 2025-02-17]

Implementation Details

The fix is implemented in src/transformers/modeling_utils.py, adding a single line to mark the new LM head as initialized immediately after creation.

Backwards Compatibility

This change:

  • Does not modify existing APIs
  • Maintains current behavior for tied embeddings
  • Only affects the internal initialization state of resized LM heads

Fixes #35141

cc: @ArthurZucker @Rocketknight1

@Rocketknight1
Member

Hi @sambhavnoobcoder, sorry for missing this one earlier! The solution seems good, but can we move the test into an existing file rather than a separate file for just that test? I think it might fit in the test_tokenization_common.py file, although you'll need to remove the lines referring to a specific model.

@sambhavnoobcoder force-pushed the output-embedding-reinitilaization branch from ca71974 to fc7c069 on April 23, 2025 at 07:47


Development

Successfully merging this pull request may close these issues.

resizing token embeddings causes output embedding to be reinitialized in post_init when tie_word_embedding is False
