-
Notifications
You must be signed in to change notification settings - Fork 6.7k
[From pretrained] Speed-up loading from cache #2515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7c3c3e1
d1ad4d3
19a0fdf
300544a
152f902
d13141a
c4a49e6
513b213
c4aadde
b43be19
22eeb11
a37cb95
0f39ab7
d6a1815
79afaf2
e4bff0b
30717b0
71fa6b8
a07ed0f
5f2472e
63bf2f8
6aeb8e3
aabdde8
d085f06
26aacc0
4590c99
b569eb8
3470424
d28e8d4
0359a96
343a330
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -458,18 +458,34 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| " dispatching. Please make sure to set `low_cpu_mem_usage=True`." | ||
| ) | ||
|
|
||
| # Load config if we don't provide a configuration | ||
| config_path = pretrained_model_name_or_path | ||
|
|
||
| user_agent = { | ||
| "diffusers": __version__, | ||
| "file_type": "model", | ||
| "framework": "pytorch", | ||
| } | ||
|
|
||
| # Load config if we don't provide a configuration | ||
| config_path = pretrained_model_name_or_path | ||
|
|
||
| # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the | ||
| # Load model | ||
|
|
||
| # load config | ||
| config, unused_kwargs, commit_hash = cls.load_config( | ||
| config_path, | ||
| cache_dir=cache_dir, | ||
| return_unused_kwargs=True, | ||
| return_commit_hash=True, | ||
| force_download=force_download, | ||
| resume_download=resume_download, | ||
| proxies=proxies, | ||
| local_files_only=local_files_only, | ||
| use_auth_token=use_auth_token, | ||
| revision=revision, | ||
| subfolder=subfolder, | ||
| device_map=device_map, | ||
| user_agent=user_agent, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| # load model | ||
| model_file = None | ||
| if from_flax: | ||
| model_file = _get_model_file( | ||
|
|
@@ -484,20 +500,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| revision=revision, | ||
| subfolder=subfolder, | ||
| user_agent=user_agent, | ||
| ) | ||
| config, unused_kwargs = cls.load_config( | ||
| config_path, | ||
| cache_dir=cache_dir, | ||
| return_unused_kwargs=True, | ||
| force_download=force_download, | ||
| resume_download=resume_download, | ||
| proxies=proxies, | ||
| local_files_only=local_files_only, | ||
| use_auth_token=use_auth_token, | ||
| revision=revision, | ||
| subfolder=subfolder, | ||
| device_map=device_map, | ||
| **kwargs, | ||
| commit_hash=commit_hash, | ||
| ) | ||
| model = cls.from_config(config, **unused_kwargs) | ||
|
|
||
|
|
@@ -520,6 +523,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| revision=revision, | ||
| subfolder=subfolder, | ||
| user_agent=user_agent, | ||
| commit_hash=commit_hash, | ||
| ) | ||
| except: # noqa: E722 | ||
| pass | ||
|
|
@@ -536,25 +540,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| revision=revision, | ||
| subfolder=subfolder, | ||
| user_agent=user_agent, | ||
| commit_hash=commit_hash, | ||
| ) | ||
|
|
||
| if low_cpu_mem_usage: | ||
| # Instantiate model with empty weights | ||
| with accelerate.init_empty_weights(): | ||
| config, unused_kwargs = cls.load_config( | ||
| config_path, | ||
| cache_dir=cache_dir, | ||
| return_unused_kwargs=True, | ||
| force_download=force_download, | ||
| resume_download=resume_download, | ||
| proxies=proxies, | ||
| local_files_only=local_files_only, | ||
| use_auth_token=use_auth_token, | ||
| revision=revision, | ||
| subfolder=subfolder, | ||
| device_map=device_map, | ||
| **kwargs, | ||
| ) | ||
| model = cls.from_config(config, **unused_kwargs) | ||
|
|
||
| # if device_map is None, load the state dict and move the params from meta device to the cpu | ||
|
|
@@ -593,20 +584,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| "error_msgs": [], | ||
| } | ||
| else: | ||
| config, unused_kwargs = cls.load_config( | ||
| config_path, | ||
| cache_dir=cache_dir, | ||
| return_unused_kwargs=True, | ||
| force_download=force_download, | ||
| resume_download=resume_download, | ||
| proxies=proxies, | ||
| local_files_only=local_files_only, | ||
| use_auth_token=use_auth_token, | ||
| revision=revision, | ||
| subfolder=subfolder, | ||
| device_map=device_map, | ||
| **kwargs, | ||
| ) | ||
| model = cls.from_config(config, **unused_kwargs) | ||
|
|
||
| state_dict = load_state_dict(model_file, variant=variant) | ||
|
|
@@ -803,6 +780,7 @@ def _get_model_file( | |
| use_auth_token, | ||
| user_agent, | ||
| revision, | ||
| commit_hash=None, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that the Another possibility is to pass only
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still need |
||
| ): | ||
| pretrained_model_name_or_path = str(pretrained_model_name_or_path) | ||
| if os.path.isfile(pretrained_model_name_or_path): | ||
|
|
@@ -840,7 +818,7 @@ def _get_model_file( | |
| use_auth_token=use_auth_token, | ||
| user_agent=user_agent, | ||
| subfolder=subfolder, | ||
| revision=revision, | ||
| revision=revision or commit_hash, | ||
| ) | ||
| warnings.warn( | ||
| f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", | ||
|
|
@@ -865,7 +843,7 @@ def _get_model_file( | |
| use_auth_token=use_auth_token, | ||
| user_agent=user_agent, | ||
| subfolder=subfolder, | ||
| revision=revision, | ||
| revision=revision or commit_hash, | ||
| ) | ||
| return model_file | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This testing library is super useful to check how many HEAD, GET requests were made, is popular and very lightweight, so think we can add it here to help with testing.