[loading] Really initialize on meta device for huge perf gains (#42941)

Cyrilvallez merged 24 commits into main
Conversation
- with init_empty_weights():
+ with torch.device("meta"):
cc @SunMarc for this change, do you know if this switch is fine for all the quantizers' `replace_with_xxx` functions? Basically, the only difference is that if the quantized layer registers some buffers, they will now be on meta as well. I checked for bnb and there it seems to be alright at least (no buffers)
as long as it's not a non-persistent buffer, it should be fine!
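For context, here is a minimal sketch (plain `torch.nn`, not the transformers or quantizer code) of the behaviour discussed above: inside the `torch.device("meta")` context, any tensor a module creates in its `__init__`, parameters and registered buffers alike, lands on the meta device.

```python
import torch
import torch.nn as nn

# Under the meta-device context, factory calls made inside a module's
# __init__ produce meta tensors, so registered buffers end up on meta too
# and must be materialized with real values before use.
with torch.device("meta"):
    layer = nn.BatchNorm2d(8)  # registers running_mean / running_var buffers

print(layer.weight.device)        # meta
print(layer.running_mean.device)  # meta
```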
ArthurZucker left a comment:
An early Christmas gift for everyone
if is_accelerate_available():
    from accelerate import init_empty_weights
Squashed commits (#42941):
* use meta device directly
* style
* move back non-persistent
* fix
* make helper
* fix it
* use native param dtype
* make tensors buffers
* style
* fix
* oupsi
* add a test and fix
* fix
* create timm integration to reinit non-persistent buffers
* style
* style
* more
* better
* add doc
* more timm stuff
* more
* fix
* small change
* no actually it was fine before
What does this PR do?
Follow-up to #42309 to really leverage meta-device loading. Gives crazy speedups for loading some models, e.g. about 2.5x on gpt-oss 20b and about 3x on the 120b version.
The issue at hand
Currently, during loading we initialize the model on meta device before loading weights, thanks to `init_empty_weights` from `accelerate`. However, this context manager has BIG drawbacks: in particular, parameters are first created on cpu inside each module's `__init__` and only then moved to the meta device.

For some models, e.g. gpt-oss, we have the following during loading:

[profiling trace of `from_pretrained` before this PR]

Note how most of the loading time is spent BEFORE the actual loading of the weights (the `_load_pretrained` call), just to initialize parameters that should be on meta device anyway.

What this PR is doing
This PR completely removes `init_empty_weights` in favor of `torch.device("meta")`, to really start with a model on meta device without first putting the parameters on cpu. We are free to do so since I merged #42309 yesterday, which correctly handles re-initialization of the non-persistent buffers that end up on meta device as well.
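A minimal sketch of the switch, using a config-only model build for illustration (the checkpoint id is just an example; in the library the change lives inside `from_pretrained` itself):

```python
import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("openai/gpt-oss-20b")  # example checkpoint

# Old approach: accelerate's context manager. Parameters are first created
# on cpu in each module's __init__ and only then swapped onto the meta device.
with init_empty_weights():
    model_old = AutoModelForCausalLM.from_config(config)

# New approach: the native torch context. Parameters and buffers are created
# directly on the meta device, skipping the cpu round-trip entirely.
with torch.device("meta"):
    model_new = AutoModelForCausalLM.from_config(config)

print(model_old.lm_head.weight.device, model_new.lm_head.weight.device)  # meta meta
```

Both paths end with the parameters on meta; the difference is purely in how much work happens on cpu before getting there.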
The performance gains are immense; the same benchmark as before for gpt-oss shows it:

[profiling trace of `from_pretrained` on this PR]
We can now see that most of the time in `from_pretrained` is spent on the actual weight loading (the `_load_pretrained` call), as it should be.

Raw numbers on our cluster for a simple benchmark script (from which the above traces are taken; a sketch of such a script is given after the numbers below) are the following:
BEFORE THIS PR -> ~9.1s
ON THIS PR -> ~3.7s
Which means a speedup of about 2.5x.
For the 120B gpt-oss version, we have:
BEFORE THIS PR -> ~32.4s
ON THIS PR -> ~10.7s
or a speedup of about 3x.
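For reference, a minimal sketch of what such a timing script could look like (the checkpoint id and the use of `time.perf_counter` are assumptions, not the exact script behind the numbers above):

```python
import time
from transformers import AutoModelForCausalLM

model_id = "openai/gpt-oss-20b"  # assumed 20b checkpoint; use the 120b variant for the second set of numbers

start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained(model_id)
print(f"from_pretrained took {time.perf_counter() - start:.1f} s")
```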