
[loading] Really initialize on meta device for huge perf gains#42941

Merged
Cyrilvallez merged 24 commits into main from init-meta on Dec 19, 2025

Conversation

@Cyrilvallez (Member) commented Dec 18, 2025

What does this PR do?

Follow-up to #42309 to really leverage meta-device loading. It gives large speedups when loading some models: about 2.5x on gpt-oss 20B and about 3x on the 120B version.

The issue at hand

Currently, during loading we initialize the model on the meta device before loading weights, using init_empty_weights from accelerate.
However, this context manager has big drawbacks:

  • everything (every parameter and buffer) is first materialized on CPU before being moved to the meta device
  • this is extremely inefficient, since we only want them on meta -> it wastes both time and memory
  • all buffers stay on CPU (even the persistent ones, which we load again from the checkpoint anyway, so they don't need to be there)
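The cost described above can be sketched without torch at all. In this illustration, `cpu_alloc`, `MetaTensor`, and the byte counter are made-up names (not transformers or accelerate APIs): the old path materializes real storage only to throw it away, while direct meta initialization never touches real memory.

```python
from dataclasses import dataclass

bytes_allocated = 0  # counts real (CPU) bytes ever materialized


@dataclass
class MetaTensor:
    """Placeholder carrying only shape/dtype, like a tensor on the meta device."""
    shape: tuple
    itemsize: int = 4  # assume float32


def cpu_alloc(shape, itemsize=4):
    """Materialize real storage (stand-in for an eager CPU tensor allocation)."""
    global bytes_allocated
    n = 1
    for d in shape:
        n *= d
    bytes_allocated += n * itemsize
    return bytearray(n * itemsize)


def init_then_move_to_meta(shape):
    """init_empty_weights-style: allocate on CPU first, then throw it away."""
    _ = cpu_alloc(shape)  # wasted time and memory
    return MetaTensor(shape)


def init_directly_on_meta(shape):
    """meta-device-style: record shape/dtype only, no real allocation."""
    return MetaTensor(shape)


init_then_move_to_meta((1024, 1024))
wasted = bytes_allocated  # 4 MiB touched for nothing
init_directly_on_meta((1024, 1024))
print(f"CPU bytes touched: old path {wasted}, meta path {bytes_allocated - wasted}")
```

Multiply that waste by every parameter and buffer of a 20B-parameter model and the trace below is unsurprising.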

For some models, e.g. gpt-oss, we have the following during loading:

[Trace screenshot, 2025-12-19: most of the from_pretrained time is spent before _load_pretrained]

Note how most of the loading time is spent BEFORE the actual loading of the weights (the _load_pretrained call), just to initialize parameters that should be on the meta device anyway.

What this PR is doing

This PR completely removes init_empty_weights in favor of torch.device("meta"), so the model really starts on the meta device without its tensors ever being put on CPU.
We are free to do so since #42309 (merged yesterday) correctly handles re-initialization of the non-persistent buffers, which are now put on the meta device as well.
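In torch terms, the swap boils down to changing the context manager around model construction. This is a minimal sketch, not the actual transformers code (which wraps its own model classes, not nn.Linear):

```python
import torch

# BEFORE (simplified): accelerate's init_empty_weights materializes every
# parameter on CPU first, then moves it to meta, and leaves buffers on CPU.
#
# AFTER: constructing the module under torch.device("meta") (a context
# manager since torch 2.0) creates every parameter and buffer directly on
# the meta device, with no CPU allocation at all.
with torch.device("meta"):
    layer = torch.nn.Linear(16, 4)

print(layer.weight.device)  # meta
```

Tensors on meta carry only shape and dtype, which is all from_pretrained needs before the real weights stream in from the checkpoint.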

The performance gains are immense, the same benchmark as before for gpt-oss shows it:

[Trace screenshot, 2025-12-19: from_pretrained time is now dominated by _load_pretrained]

and we can now see that most of the time in from_pretrained is spent on the actual weight loading (the _load_pretrained call), as it should be.

Raw numbers on our cluster for the following simple benchmark script (which also produced the traces above):

```python
from transformers import AutoModelForCausalLM
import torch
import time
from viztracer import VizTracer

model_id = "openai/gpt-oss-20b"
device = 0

# Trace the whole call to see where the loading time goes
tracer = VizTracer()
tracer.start()

torch.cuda.synchronize()
t0 = time.time()
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, dtype=torch.float16)
torch.cuda.synchronize()
dt = time.time() - t0
print(f"Took {dt:.2f} s")

tracer.stop()
tracer.save(output_file="../trace.json")
```

are the following:

BEFORE THIS PR -> ~9.1s
ON THIS PR -> ~3.7s

Which means a speedup of about 2.5x.

For the 120B gpt-oss version, we have:

BEFORE THIS PR -> ~32.4s
ON THIS PR -> ~10.7s

or a speedup of about 3x.
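As a quick sanity check on those ratios, using the wall-clock times reported above:

```python
# Load times (seconds) for gpt-oss before and after this PR, as reported above
timings = {
    "gpt-oss-20b": (9.1, 3.7),
    "gpt-oss-120b": (32.4, 10.7),
}

for model, (before, after) in timings.items():
    print(f"{model}: {before / after:.1f}x speedup")
```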

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez Cyrilvallez changed the title Init meta [loading] Really initialize on meta device for huge perf gains Dec 19, 2025
Comment on lines -100 to +97
with init_empty_weights():
with torch.device("meta"):
Member Author


cc @SunMarc for this change, do you know if this switch on all the quantizers' replace_with_xxx is fine? Basically, the only difference is that if the quantized layer registers some buffers, they would now be on meta as well. I checked for bnb and there it seems to be alright at least (no buffers)

Member


As long as it's not a non-persistent buffer, it should be fine!
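The distinction matters because only persistent buffers live in the state_dict, so the checkpoint overwrites them during loading and starting them on meta is harmless; non-persistent buffers are absent from the checkpoint and must be re-initialized after loading (the case #42309 handles). A minimal illustration:

```python
import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: part of state_dict, so the checkpoint
        # overwrites it during loading; creating it on meta is harmless.
        self.register_buffer("scale", torch.ones(1))
        # Non-persistent buffer: absent from state_dict, so if it is
        # created on meta, something must re-initialize it after loading.
        self.register_buffer("cache", torch.zeros(4), persistent=False)


keys = set(Toy().state_dict().keys())
print(keys)  # only the persistent buffer appears
```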

Collaborator

@ArthurZucker ArthurZucker left a comment


An early Christmas gift for everyone

Comment on lines -20 to -21
if is_accelerate_available():
from accelerate import init_empty_weights
Collaborator


💌

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: chameleon, codegen, ctrl, deepseek_vl, deepseek_vl_hybrid, emu3, eomt, fastspeech2_conformer, gemma, gemma2, gptj, idefics2, idefics3, janus, layoutlmv3, llava_next

@Cyrilvallez Cyrilvallez merged commit bb9357f into main Dec 19, 2025
26 checks passed
@Cyrilvallez Cyrilvallez deleted the init-meta branch December 19, 2025 13:43
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
…ngface#42941)

* use meta device directly

* style

* move back non-persistent

* fix

* make helper

* fix it

* use native param dtype

* make tensors buffers

* style

* fix

* oupsi

* add a test and fix

* fix

* create timm integration to reinit non-persistemnt buffers....

* style

* style

* more

* better

* add doc

* more timm stuff

* more

* fix

* small change

* no actually it was fine before
