[model loading] don't init weights for pretrained models #11463
stas00 wants to merge 2 commits into huggingface:master from
Conversation
sgugger left a comment
This won't work, sadly, for two reasons.
- First, everything in the `__dict__` of the config gets serialized when one uses `config.save_pretrained()` (which is called by `model.save_pretrained`), so any other model downloaded from the hub with a checkpoint saved after this is merged will get this attribute in the config. Then if a user instantiates a randomly-initialized model using the config, with the following code:

  ```python
  config = AutoConfig.from_pretrained("new_checkpoint_after_this_is_merge")
  model = AutoModel.from_config(config)
  ```

  then the model won't be randomly initialized (at least not with `_init_weights`), since the config will have this `use_pretrained_weights` (a sketch of the leak follows this list).
- Then comes the problem that a model instantiated with `from_pretrained` does not necessarily have all its weights initialized from the checkpoint (if you discard the head to put another task-specific head), and this PR will break the way those weights are randomly initialized.
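
To make the first point concrete, here is a minimal sketch (not from the PR; the `use_pretrained_weights` name is taken from the diff below, and the checkpoint path is illustrative) of how any attribute placed on a config object ends up in the saved `config.json` and therefore in every checkpoint derived from it:

```python
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("bert-base-uncased")
config.use_pretrained_weights = True      # hypothetical flag this PR would set
config.save_pretrained("my_checkpoint")   # config.json now contains the flag

# A user who later builds a *random* model from that config silently inherits
# the flag, so _init_weights would be skipped for their fresh model too.
reloaded = AutoConfig.from_pretrained("my_checkpoint")
model = AutoModel.from_config(reloaded)   # would no longer be randomly initialized
```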
I sadly don't see a way around passing a list of not-initialized weights from `from_pretrained` to the `_init_weights` function.
So if I find another way to do it that doesn't taint the config, then it's OK, right (as far as config correctness goes)? e.g. what if I unset this config value as soon as
I appreciate that you could think of the edge cases. Clearly, we don't have any tests that somehow verify that the init is done correctly. I was hoping that there would be some, but these would be hard to conjure. If you feel this is a worthwhile effort, perhaps let's start coming up with examples, write tests if possible, and solve those? You can throw the edge cases at me and I will try to overcome them.

Or, alternatively, we provide a very easy way for users to either force the init, or, if it's safer, to force no-init? e.g. the staple examples could all enforce no-init and explain how to change that if the user wants to modify the example to have the original behavior.

So what I'm suggesting is that instead of
| """ | ||
| # Initialize weights | ||
| self.apply(self._init_weights) | ||
| if getattr(self.config, "use_pretrained_weights", False): |
- I'm not a huge fan of "attaching" a new parameter to the `config` which is not really understandable by the user.
- Also, I think this could lead to problems -> lots of people initialize all weights except the final layer weights from a pre-trained BERT in, e.g. a `BertForSequenceClassification`. The logic would then not correctly initialize the final layer, but simply set everything to zero, which would probably lead to a worse fine-tuning of `BertForSequenceClassification`.
=> I would propose the following:
- When using `from_pretrained(...)`, we pass a new parameter to `model = cls(config, *model_args, **model_kwargs)` by setting `model_kwargs["init_weights"] = False`. This then sadly means that we have to replace all `__init__(self, config)` functions in the modeling files with `__init__(self, config, init_weights=True)`, but I think we can use a regex for this (a sketch of the change follows this list). This is a huge change in terms of files that need to be changed, but I think it's cleaner than creating a new `"use_pretrained_weights"` config parameter that the user shouldn't have to learn about. Then, we also need to replace `self.init_weights()` with

  ```python
  if init_weights:
      self.init_weights()
  ```

- Now, we also need to take care of cases where only parts of the model are initialized from pre-trained weights. The other part still needs to be initialized. Here we can't simply use `self.init_weights()` because it would necessarily run through all modules and initialize them. So I think we should leverage the `missing_keys` list here to extract all `nn.Module`s that still need to be initialized, and then run `self._init_weights(m) for m in uninitialized_modules`.
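
A rough sketch of what the proposed per-model change could look like, using an abbreviated `BertModel` (illustrative only; the real class has more submodules, and the `init_weights` kwarg is the one from the proposal above):

```python
from transformers import BertConfig
from transformers.models.bert.modeling_bert import (
    BertEmbeddings,
    BertEncoder,
    BertPreTrainedModel,
)


class BertModel(BertPreTrainedModel):
    def __init__(self, config: BertConfig, init_weights: bool = True):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        # from_pretrained would pass init_weights=False via model_kwargs,
        # since the checkpoint will overwrite these values right away.
        if init_weights:
            self.init_weights()
```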
What do you think, @stas00?
Also keen to hear @LysandreJik's and @sgugger's opinions here.
The `init_weights` kwarg by itself will not work, as it doesn't deal with 2. As I said in my comment, the only way to properly deal with this is to pass an `uninitialized_weights` kwarg (as given by the `missing_keys`), which would then be used:

```python
if len(uninitialized_weights) > 0:
    self.init_weights(uninitialized_weights)
```

and of course `init_weights` then needs to use a function different from `apply` that only applies `_init_weights` to the `uninitialized_weights`.
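
One possible shape for that, sketched with assumed names (`uninitialized_weights` being the parameter names reported as `missing_keys`; this is not the actual implementation):

```python
def init_weights(self, uninitialized_weights=None):
    if uninitialized_weights is None:
        # current behavior: randomly initialize every module
        self.apply(self._init_weights)
        return
    # only (re-)initialize modules that directly own a parameter
    # missing from the loaded checkpoint
    for name, module in self.named_modules():
        direct_params = (
            f"{name}.{param_name}"
            for param_name, _ in module.named_parameters(recurse=False)
        )
        if any(key in uninitialized_weights for key in direct_params):
            self._init_weights(module)
```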
The `init_weights` kwarg on its own won't work, but it's necessary to prevent each model from calling `self.init_weights()`.
The order of operations when doing `BertModel.from_pretrained(...)` is the following:
1. Instantiate a random model: `cls(config, *model_args, **model_kwargs)` => this command already calls `self.init_weights(...)` (since every model class has a `self.init_weights(...)` in `__init__(config)`). So in order to prevent this we need to pass a flag to `cls(config, *model_args, **model_kwargs)`, which I would do with `model_kwargs["init_weights"] = False`.
2. Only after the model is instantiated (and the weights already have values) can we know which weights were missing and thus need to be randomly initialized. Here we could retrieve `uninitialized_weights`, but it would be better to actually retrieve all `nn.Module`s that are to be randomly initialized, since then we can reuse each model's `_init_weights(...)` function.
3. Having retrieved `uninitialized_modules`, we can run `self._init_weights(...)` on each module (the whole flow is sketched below).
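
Put together, the flow described above could be sketched like this (an assumed helper shape, not the actual `from_pretrained` code):

```python
def load_pretrained(cls, config, state_dict, *model_args, **model_kwargs):
    # 1. Instantiate without running the throwaway random init
    model_kwargs["init_weights"] = False
    model = cls(config, *model_args, **model_kwargs)

    # 2. Load the checkpoint; missing_keys lists parameters it did not provide
    missing_keys, _unexpected = model.load_state_dict(state_dict, strict=False)

    # 3. Randomly initialize only modules on the path of a missing parameter
    #    (_init_weights is a no-op on plain container modules)
    for name, module in model.named_modules():
        if name and any(key.startswith(f"{name}.") for key in missing_keys):
            model._init_weights(module)
    return model
```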
That is not the edge case but the overwhelming majority ;-) You are mostly working with seq2seq models that don't throw away any weights when doing transfer learning, but all the basic examples fine-tuning BERT on a classification task encounter this :-)

Testing that the init is done properly is very difficult, as those are all random weights. Testing that those weights follow this distribution instead of that one is not something easily achievable. I don't think the
OK, let's move the effort to Patrick's PR #11471
Guilty as charged. I'm glad you guys have a much wider view than I do. Thank you!
Skip `_init_weights` for pretrained models since they get immediately replaced by pretrained weights. This leads to a much faster startup for huge models.

Fixes: #9205
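
The redundant work this targets, in a nutshell (a sketch; actual timings depend on the model size):

```python
import time
from transformers import BertModel

# Today, from_pretrained (1) instantiates the model, whose __init__ calls
# self.init_weights() and runs _init_weights over every parameter, and then
# (2) loads the checkpoint, overwriting all of those freshly-initialized
# values. Step (1) is pure throwaway work that grows with model size.
start = time.time()
model = BertModel.from_pretrained("bert-large-uncased")
print(f"load took {time.time() - start:.1f}s (includes the redundant init)")
```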
@sgugger, @patrickvonplaten