[performance] from_pretrained is still much slower than torch.load and seems to be initializing weights #21913

@moyix

Description

System Info

  • transformers version: 4.26.1
  • Platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.31
  • Python version: 3.10.9
  • Huggingface_hub version: 0.12.1
  • PyTorch version (GPU?): 2.0.0.dev20230224+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@stas00, @patrickvonplaten

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Loading a model with from_pretrained takes much longer than the underlying torch.load. For example, for the Salesforce/codegen-6B-mono model, CodeGenForCausalLM.from_pretrained('Salesforce/codegen-6B-mono') takes ~38 seconds, whereas torch.load() on its pytorch_model.bin takes just ~5.4 seconds. This is very similar to #9205, but is happening with the latest transformers from pip (4.26.1), so possibly a regression?

Short repro:

import time
import torch
from transformers import CodeGenForCausalLM
t1 = time.time()
CodeGenForCausalLM.from_pretrained('Salesforce/codegen-6B-mono')
t2 = time.time()
print("Load took", t2-t1, "seconds")

Prints Load took 37.78910255432129 seconds

import time
import torch
from transformers.utils import cached_file
t1 = time.time()
torch.load(cached_file('Salesforce/codegen-6B-mono', 'pytorch_model.bin'))
t2 = time.time()
print("Load took", t2-t1, "seconds")

Prints Load took 5.443041801452637 seconds

Based on profiling the from_pretrained code path with cProfile, ~75% of the time is spent randomly initializing weights that are immediately overwritten by the checkpoint. This is the same problem that was fixed in PR #11471, so I'm not sure what's going on here.
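For comparison, here is a sketch of the kind of init-skipping the fix should restore. This is not necessarily how transformers implements it internally; it just illustrates the mechanism using PyTorch's torch.nn.utils.skip_init, which constructs a module without running its (potentially expensive) weight-initialization code:

```python
import time
import torch

# Normal construction: runs the module's default weight initialization.
t1 = time.time()
layer = torch.nn.Linear(4096, 4096)
normal_secs = time.time() - t1

# skip_init construction: parameters are allocated but left uninitialized,
# on the assumption that a checkpoint will overwrite them anyway.
t1 = time.time()
layer = torch.nn.utils.skip_init(torch.nn.Linear, 4096, 4096)
skipped_secs = time.time() - t1

print(f"with init: {normal_secs:.3f}s, skipping init: {skipped_secs:.3f}s")
```

As a user-side workaround, passing low_cpu_mem_usage=True to from_pretrained also avoids the upfront initialization (it builds the model with empty weights first), though the default path should not need it.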

Here's the cProfile output and output from gprof2dot:
loadmodel_profile.txt
hf_loadmodel_new.pdf

Expected behavior

from_pretrained should skip weight initialization when loading a pretrained model.
