Enable different torch dtype in sub models #34873
zucchini-nlp merged 10 commits into huggingface:main
Conversation
```python
    for sub_config_key in config.sub_configs.keys():
        sub_config = getattr(config, sub_config_key)
        sub_config.torch_dtype = torch_dtype
elif isinstance(torch_dtype, dict):
    for key, curr_dtype in torch_dtype.items():
        if hasattr(config, key):
            value = getattr(config, key)
            value.torch_dtype = curr_dtype
```
If a user passes one torch dtype, as before, we just use it in all sub-configs. Otherwise a user can either set dtypes directly on the sub-configs before loading the model, or pass a dict `torch_dtype` when loading, similarly to how `attn_implementation` is dispatched.
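The dispatch rule described above can be sketched in plain Python. This is an illustrative mock, not the actual transformers code path: `dispatch_torch_dtype` and the `SimpleNamespace` configs are hypothetical stand-ins for `PretrainedConfig` objects.

```python
from types import SimpleNamespace


def dispatch_torch_dtype(config, torch_dtype):
    """Apply one dtype to all sub-configs, or a dict of dtypes keyed by sub-config name.

    Hypothetical sketch; dtypes are plain strings here instead of torch.dtype objects.
    """
    if isinstance(torch_dtype, str):
        # Single dtype: propagate it to every sub-config.
        for sub_config_key in config.sub_configs:
            getattr(config, sub_config_key).torch_dtype = torch_dtype
    elif isinstance(torch_dtype, dict):
        # Dict dtype: each key names a sub-config attribute on the config.
        for key, curr_dtype in torch_dtype.items():
            if hasattr(config, key):
                getattr(config, key).torch_dtype = curr_dtype


config = SimpleNamespace(
    sub_configs=["text_config", "vision_config"],
    text_config=SimpleNamespace(torch_dtype=None),
    vision_config=SimpleNamespace(torch_dtype=None),
)

dispatch_torch_dtype(config, {"text_config": "float16", "vision_config": "float32"})
print(config.text_config.torch_dtype)    # float16
print(config.vision_config.torch_dtype)  # float32
```

Passing a single dtype instead (e.g. `"bfloat16"`) would overwrite both sub-config dtypes, matching the pre-existing behavior.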
ArthurZucker
left a comment
Thanks, can you add a test showcasing an example usage of this (for example with Llava!)
ArthurZucker
left a comment
This is interesting but I am not sure we have everything ready:
- does it work with the keep-in-float32 attribute as well?
- does it work well with models that have enforced param dtypes? (Some vision models have this!)

So maybe a little bit of testing is missing. Makes a lot of sense otherwise!
Yeah, this one needs time and I'll come back after the model releases to make sure it works in all cases. Currently it has weird behavior in nested configs where a general text config has an attribute
Conditions are needed to make the dtype dispatch correctly. The current design supports setting the dtype either directly on the sub-configs or via a dict `torch_dtype` at load time. Added more tests and verified it works in those cases.
ArthurZucker
left a comment
There seems to be a big breaking change, no?
Not breaking anymore, will need at least one approval to merge :)
ArthurZucker
left a comment
Good to go but I'd rather wait for the next release 🤗
Cool, since the release is done I will rebase and merge.
What does this PR do?
Fixes #33997. Enables users to use a different torch dtype for each sub-config, for example loading the vision model in full precision and the text model in half precision.
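The intended effect can be illustrated end-to-end with a toy stand-in for the loading path. This is a hypothetical sketch, not the real `from_pretrained` implementation: `load_toy_model`, the sub-model names, and the string dtypes are all invented for illustration; the idea is that each sub-model's parameters end up in the dtype requested for its sub-config.

```python
from types import SimpleNamespace


def load_toy_model(torch_dtype):
    """Hypothetical stand-in for from_pretrained(..., torch_dtype=...).

    A dict torch_dtype assigns a dtype per sub-config; a single value
    applies to all sub-configs (the previous behavior).
    """
    # Default dtype for each sub-config before dispatch.
    sub_configs = {"vision_config": "float32", "text_config": "float32"}
    if isinstance(torch_dtype, dict):
        sub_configs.update(torch_dtype)
    else:
        sub_configs = {key: torch_dtype for key in sub_configs}
    # "Parameters" here are just records tagged with the dtype they were cast to.
    return SimpleNamespace(
        vision_tower=SimpleNamespace(param_dtype=sub_configs["vision_config"]),
        language_model=SimpleNamespace(param_dtype=sub_configs["text_config"]),
    )


# Vision sub-model in full precision, text sub-model in half precision:
model = load_toy_model({"vision_config": "float32", "text_config": "float16"})
print(model.vision_tower.param_dtype)    # float32
print(model.language_model.param_dtype)  # float16
```

With the real API, the dict keys would name sub-configs (e.g. `text_config`, `vision_config`) and the values would be actual `torch.dtype` objects.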