Default auto 🚨 🚨 #42805
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Cyrilvallez left a comment
Very happy to break this and have auto by default!!
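For context, a minimal sketch of the new default behavior (the `dtype` argument and the float16 result for opt-125m are taken from the discussion below; treat the rest as an assumption):

```python
import torch
from transformers import AutoModelForCausalLM

# With this PR, omitting dtype is assumed to behave like dtype="auto": weights
# load in the checkpoint's stored dtype (float16 for opt-125m) instead of
# being upcast to torch.float32, torch's default dtype.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
assert model.dtype == torch.float16
```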
```diff
-dtype = get_state_dict_dtype(state_dict)
+dtype = get_state_dict_dtype(state_dict, getattr(config, "dtype", None))
```
Why do we need to change this? Both calls are already inside branches where we know that `config.dtype` does not exist anyway.
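To spell out that argument, a hedged sketch (illustrative names and control flow, not the actual `from_pretrained` internals): if the helper is only ever called on the branch where `config.dtype` is absent, forwarding it is a no-op.

```python
import torch

def get_state_dict_dtype(state_dict):
    # stand-in for the real helper: first floating-point dtype in the weights
    return next(t.dtype for t in state_dict.values() if t.is_floating_point())

def resolve_dtype(config, state_dict):
    config_dtype = getattr(config, "dtype", None)
    if config_dtype is not None:
        return config_dtype
    # This branch is only reached when config.dtype is absent, so forwarding
    # getattr(config, "dtype", None) to the helper would always pass None.
    return get_state_dict_dtype(state_dict)
```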
```diff
-def get_state_dict_dtype(state_dict):
+def get_state_dict_dtype(state_dict, config_dtype: Optional[torch.dtype] = None):
```
I don't think this function needs to be changed, see previous comment
```python
if getattr(self.config, "dtype", None) is None:
    default_dtype = torch.get_default_dtype()
    self.config.dtype = default_dtype
    for sub_config_key in self.config.sub_configs:
        if (sub_config := getattr(self.config, sub_config_key)) is not None and getattr(
            sub_config, "dtype", None
        ) is None:
            sub_config.dtype = default_dtype
```
Do we need to write it in `__init__`? In any case, there's no need to do it on all the subconfigs: all submodels run `__init__` during the whole init process, so writing it only to `self.config` is enough; each submodel will do it for its own config.
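A hedged sketch of that recursion argument (toy classes, not the real `PreTrainedModel` internals): the composite model's `__init__` fixes its own config's dtype, and each submodel's `__init__` does the same for its sub-config, so no explicit loop over `sub_configs` is needed.

```python
import torch

class FakePreTrainedModel:
    def __init__(self, config):
        self.config = config
        # every model, top-level or nested, patches its own config on init
        if getattr(config, "dtype", None) is None:
            config.dtype = torch.get_default_dtype()

class FakeSubConfig:
    dtype = None

class FakeCompositeConfig:
    dtype = None
    def __init__(self):
        self.text_config = FakeSubConfig()

class FakeCompositeModel(FakePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)  # sets config.dtype for the top-level config
        # the submodel runs the same __init__ on its own sub-config
        self.text_model = FakePreTrainedModel(config.text_config)

config = FakeCompositeConfig()
model = FakeCompositeModel(config)
assert config.dtype == torch.get_default_dtype()
assert config.text_config.dtype == torch.get_default_dtype()
```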
[For maintainers] Suggested jobs to run (before merge): run-slow: beit, bigbird_pegasus, blip_2, data2vec, edgetam, gpt_oss, internvl, sam2, sam3, timm_wrapper
Cyrilvallez left a comment
Let's gooooo the nice defaults 🚀🚨
I understand that this is a big change that is intended to break some stuff. E.g. for opt-125m, the auto dtype is now float16 instead of float32. However, this PR also makes it so that the following snippet now fails:

```python
import torch
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-125m"

# previous commit (8a2a83d574fd461697a29410a36737ed112f8ba7)
# this passes
model = AutoModelForCausalLM.from_pretrained(model_id)
assert model.dtype == torch.float32
model.half()
assert {p.dtype for p in model.parameters()} == {torch.float16}
assert model.dtype == torch.float16, f"not float16, got {model.dtype} instead"  # passes

# after this commit (6217adc6c8f0be7b5374e6a46129ad2214e4c6ed)
model = AutoModelForCausalLM.from_pretrained(model_id)
assert model.dtype == torch.float16  # <= used to be float32
model.float()
assert {p.dtype for p in model.parameters()} == {torch.float32}
assert model.dtype == torch.float32, f"not float32, got {model.dtype} instead"  # fails
# AssertionError: not float32, got torch.float16 instead
```
Ha damn, I just merged #42825 but seeing this now... Will open a new one, it indeed looks like relying on the config is a bad idea in those cases!
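For what it's worth, a minimal sketch (toy module, not the actual transformers fix) of why deriving `dtype` from the parameters stays correct after `.float()`/`.half()`, while a value cached on the config goes stale:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    @property
    def dtype(self) -> torch.dtype:
        # read the dtype off the parameters themselves, not off a cached config
        return next(self.parameters()).dtype

model = TinyModel().half()
assert model.dtype == torch.float16
model.float()
assert model.dtype == torch.float32  # stays in sync, unlike a cached config value
```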
* default to `"auto"` dtype
* the actual change?
* up?
* style
* up?
* only sam models were broken with this
* fix sams
* update
* fix sam2 now
* up
* this?
* proper fix
* lol
* fix
* fixes
* nit
* fix
* fix copies
* fixes
* fix bigbird
* revert one bit
What does this PR do?
Supersedes #34919