Missing weights not initialized properly (#35437, PR #35913)

sambhavnoobcoder wants to merge 17 commits into huggingface:main from …

Conversation
Hi @sambhavnoobcoder, I see some failing tests in the CI! I think the cause is that in some cases, models have tied weights, meaning that the input embeddings and output projection are identical. In these cases, only one of those tensors may exist in the …
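To illustrate the tied-weights situation described above, here is a minimal standalone sketch (a toy pair of layers, not code from the PR): tying the head to the embedding makes both parameters share one underlying tensor, so a checkpoint only needs to store it once.

```python
import torch.nn as nn

# Toy example (not from the PR): tie an output projection to the input
# embedding so both parameter names point at the same underlying tensor.
embed = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = embed.weight  # weight tying

# Both parameters now share storage, so a checkpoint can store just one copy.
print(head.weight.data_ptr() == embed.weight.data_ptr())  # True
```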
@Rocketknight1 Thank you for pointing to the issue with tied weights. I've modified the initialization logic to detect tied parameters using …
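One general way to detect tied parameters — a sketch with an assumed helper name, not necessarily the approach this PR took — is to group parameter names by their underlying storage pointer:

```python
import torch.nn as nn

# Hypothetical helper (not the PR's code): group parameter names that share
# the same underlying storage, i.e. tied weights.
def find_tied_parameter_groups(model: nn.Module):
    by_ptr = {}
    for name, param in model.named_parameters(remove_duplicate=False):
        by_ptr.setdefault(param.data_ptr(), []).append(name)
    return [names for names in by_ptr.values() if len(names) > 1]

class TinyLM(nn.Module):
    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed.weight  # tied weights

model = TinyLM()
print(find_tied_parameter_groups(model))  # [['embed.weight', 'lm_head.weight']]
```

`remove_duplicate=False` matters here: by default `named_parameters` deduplicates shared tensors, which would hide exactly the tying we want to find.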
Hi @sambhavnoobcoder, debugging issues like this can be tricky. I suggest the following approach: …
This is quite advanced debugging, but it's unfortunately necessary sometimes. Good luck!
ArthurZucker left a comment:
This is important IMO! But it does not need a special additional test file!
I don't have my head in this part of the code, so @Cyrilvallez, if you want to have a look! 🤗
```python
# After loading weights, initialize missing ones properly
missing_keys = set(model.state_dict().keys()) - set(loaded_keys)
```
Why don't we use the `missing_keys` defined above?
Okay, addressed this in commit 566028f.
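For context, the diff line quoted above computes the set difference between the model's parameters and what the checkpoint provided. In isolation (with hypothetical layer names, not the transformers loading code):

```python
import torch.nn as nn

# Standalone illustration of the quoted line: "missing" keys are model
# parameters with no counterpart in the loaded checkpoint.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
loaded_keys = ["0.weight", "0.bias"]  # pretend the checkpoint only covered layer 0
missing_keys = set(model.state_dict().keys()) - set(loaded_keys)
print(sorted(missing_keys))  # ['1.bias', '1.weight']
```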
As for the tests in a separate file: the separate tests now live in `test_modeling_utils.py` only, as of commit 54dfcb4. No changes were required to the tests after the refactor.

```python
import torch
from transformers import AutoModel, AutoConfig


def test_real_model_initialization():
    """Test initialization with a real model by adding a new classification layer."""
    # Create base model and save
    model = AutoModel.from_pretrained("bert-base-uncased")
    model.save_pretrained("./test-model")

    # Modify config to add new classification layer
    config = AutoConfig.from_pretrained("./test-model")

    class BertWithClassification(type(model)):
        def __init__(self, config):
            super().__init__(config)
            # Add a new classification layer
            self.classifier = torch.nn.Linear(config.hidden_size, 3)  # 3 classes

        def _init_weights(self, module):
            super()._init_weights(module)
            if isinstance(module, torch.nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    # Load with new architecture - should initialize new layer properly
    new_model = BertWithClassification.from_pretrained("./test-model", config=config)

    # Verify no NaN values in new layer
    assert not torch.isnan(new_model.classifier.weight).any(), "NaN found in classifier weights"
    assert not torch.isnan(new_model.classifier.bias).any(), "NaN found in classifier bias"

    # Verify base model weights are preserved
    for name, param in new_model.named_parameters():
        if "classifier" not in name:  # Skip the new layer
            orig_param = model.get_parameter(name)
            assert torch.equal(param, orig_param), f"Original weights not preserved for {name}"

    print("Real model test passed successfully!")
    return new_model


if __name__ == "__main__":
    test_real_model_initialization()
```

Thanks for the reviews. I'll make any other changes required as well @ArthurZucker @Cyrilvallez
Hey @sambhavnoobcoder! The logic of …
Hey @Cyrilvallez, …
BTW, I saw the large refactor, and the changes coming are truly awesome and much needed. However, since you have a better perspective on the upcoming changes than I do, I would appreciate it if you could look at my PR and tell me whether I need to make any more changes to future-proof it against those changes.
Hey, sorry for the late reply, we are actively working on improving the loading logic these days!
Any update on this PR? Still having issues with fresh weights added to existing modules not initialising correctly. The only solution was to revert to an older transformers version, as disabling fast init does not fix things.
Hey @Avelina9X! I can look into it if you provide a min repro of the issue. BTW, …
Hi @Cyrilvallez, thanks for the response, but I managed to fix the issue. My use case is VLM training, so I have a composite model where I have also injected additional layer types into the base language model. It turns out the new "smart apply" weight-init system was at fault here, since the weight inits of the outer VLM get overwritten by the inner language model, meaning the new layers had no valid initialisation scheme in the newer transformers versions. I fixed this by simply overriding the …

For reference, for anyone who comes across a similar problem, I solved the issue by adding this to my VLM's pretrained model class:

```python
@torch.no_grad()
def initialize_weights(self):
    # Ugly hack, but prevents the new smart apply system from breaking composite model inits
    self.apply(self._initialize_weights)
```

I do realise I should ideally change the …
Closing this since that logic has evolved considerably since the original issue, and it's not relevant anymore! Feel free to reopen if you think there is still something to address! 🤗
@Cyrilvallez You might face this when using a custom number of classes with `id2label`: the classification head does not get properly initialized (e.g. values like 1e34). Even though DETA is deprecated, this is just a code snippet to illustrate. cc @qubvel @yonigozlan
Problem Statement
When using `from_pretrained()` to load a model with new parameters that weren't in the original saved model, these new parameters were not being properly initialized according to the model's `_init_weights()` method. Instead, they remained in their default PyTorch initialization state, sometimes resulting in NaN values.

Root Cause Analysis
The issue was identified in the `_load_pretrained_model` method, where missing weights weren't being properly initialized when `_fast_init=True` (the default behavior). This caused inconsistent behavior between direct model initialization and loading via `from_pretrained()`.

Solution
Modified the `_load_pretrained_model` method to properly initialize missing weights using the model's `_init_weights()` method, regardless of the `_fast_init` setting. The solution maintains backward compatibility while ensuring consistent initialization behavior.

Implementation Details
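A simplified sketch of this approach, using assumed names and a generic `init_fn` standing in for the model's `_init_weights` (this is an illustration, not the actual transformers source):

```python
import torch.nn as nn

# Simplified sketch: after loading a checkpoint, re-run weight initialization
# only on the modules that own parameters missing from the checkpoint.
def init_missing_weights(model, missing_keys, init_fn):
    owners = {key.rsplit(".", 1)[0] for key in missing_keys if "." in key}
    for name, module in model.named_modules():
        if name in owners:
            init_fn(module)

def init_fn(module):
    # Stand-in for a model's _init_weights
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
missing = {"1.weight", "1.bias"}  # pretend layer 1 was absent from the checkpoint
init_missing_weights(model, missing, init_fn)
print(model[1].bias.abs().sum().item())  # 0.0 — the missing bias was re-initialized
```

Scoping the re-initialization to the owning modules is what preserves the loaded weights: modules fully covered by the checkpoint are never touched.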
Testing Strategy
Created a comprehensive test suite verifying:

- `_fast_init` …

Test Results
Related Issues