fix FSDP loading with meta devices (#44473)
Conversation
This PR basically restores the logic from v4.57.6 (https://github.com/huggingface/transformers/blob/v4.57.6/src/transformers/modeling_utils.py#L5853-L5877), which was removed in #41580 (https://github.com/huggingface/transformers/pull/41580/changes#diff-6b72b98c4c2dcfc6cc606843917733f5d858374fbc22a735ff483bbc0c1e63eaL5110-L5142).
Cyrilvallez
left a comment
I believe we should just completely skip the init in this case rather than mark everything as initialized, then try to initialize??
```python
# Handle FSDP edge case when using cpu ram efficient loading to ensure it is marked as initialized
# since it will get its weights broadcasted from rank0
for key in self.state_dict():
    try:
        param_or_buffer = self.get_parameter_or_buffer(key)
        param_or_buffer._is_hf_initialized = True
    except AttributeError:
        pass  # may happen when handling pre-quantized weights
self._is_hf_initialized = True
```
Should we simply return here instead, to completely avoid calling initialize_weights later in the function? That would be easier than setting all weights as initialized before calling initialize, which will be skipped anyway since the params are marked as already initialized.
Cyrilvallez
left a comment
LGTM! Thanks a lot for this!
For posterity: the issue is only with the non-persistent buffers, as they are NOT gathered from rank0 (only the state_dict is), so we still need to go through all the inits for them (while skipping everything that is in the state_dict)!
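The distinction described above can be sketched with a small pure-Python simulation (this is not the transformers implementation; `TinyModule`, its attribute names, and the `initialized` bookkeeping are made up for illustration): entries in `state_dict()` are broadcast from rank0 and can safely be marked as initialized, while non-persistent buffers never appear in `state_dict()` and therefore must still go through local initialization.

```python
# Pure-Python simulation (NOT the transformers/PyTorch code) of why
# non-persistent buffers must still be initialized locally: they never
# appear in state_dict(), so FSDP's rank0 broadcast cannot cover them.

class TinyModule:
    def __init__(self):
        self.params = {"linear.weight": 0.0}
        # name -> (value, persistent); only persistent buffers enter state_dict
        self.buffers = {
            "running_mean": (0.0, True),   # persistent: broadcast from rank0
            "rope_cache": (None, False),   # non-persistent: must init locally
        }
        self.initialized = set()

    def state_dict(self):
        sd = dict(self.params)
        sd.update({k: v for k, (v, persistent) in self.buffers.items() if persistent})
        return sd

    def mark_state_dict_initialized(self):
        # Mirrors the PR's fix: everything in state_dict will be broadcast
        # from rank0, so mark it initialized and skip it during init.
        for key in self.state_dict():
            self.initialized.add(key)

    def initialize_weights(self):
        # Init pass: skip already-initialized entries, fill in the rest.
        for name in self.params:
            if name not in self.initialized:
                self.params[name] = 1.0  # local (re)init
        for name, (value, persistent) in self.buffers.items():
            if name not in self.initialized:
                self.buffers[name] = (1.0, persistent)  # local (re)init

m = TinyModule()
m.mark_state_dict_initialized()
m.initialize_weights()
# Only the non-persistent buffer was touched by the init pass:
assert m.buffers["rope_cache"][0] == 1.0
assert m.buffers["running_mean"][0] == 0.0
assert m.params["linear.weight"] == 0.0
```

The key observation is that `"rope_cache"` is absent from `state_dict()`, so skipping everything in the state_dict still leaves the init pass free to fill it in locally on every rank.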

What does this PR do?
supersedes #44446
On main, when loading to CPU and using meta devices for non-rank0 processes, the weights are re-initialized on those processes and more CPU memory is used. Testing with llama3-8b:

- main: both ranks on CPU, uses 16GB of system RAM, slow to load, re-initializes weights on rank1
- #44446: rank0 on CPU, rank1 on meta, uses 1.5GB of system RAM
- v4.57.6: both ranks on CPU, uses 1.5GB of system RAM
- this PR: both ranks on CPU, uses 1.5GB of system RAM, same behavior and training loss as main and v4.57.6
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.