
Fix FSDP_CPU_RAM_EFFICIENT_LOADING (#43749) #43785

Open

MengAiDev wants to merge 3 commits into huggingface:main from MengAiDev:fix/fsdp-cpu-ram-efficient-loading

Conversation

@MengAiDev
Contributor

  • Set the _is_hf_initialized flag in _load_parameter_into_model to prevent unnecessary random re-initialization
  • Skip state_dict loading on non-rank-0 processes when FSDP is enabled, so they no longer waste CPU RAM
  • Together these fix the issue where every rank temporarily allocates model-sized CPU RAM and stalls for a long time during loading (see the sketch below)

Modified:

  • src/transformers/modeling_utils.py: Set _is_hf_initialized=True on newly created parameters
  • src/transformers/core_model_loading.py: Add an FSDP check to skip loading on non-rank-0 processes

Fixes #43749

@Cyrilvallez @ArthurZucker
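To make the intent concrete, here is a minimal sketch of the gating idea, not the actual patch: is_fsdp_enabled and is_local_dist_rank_0 are the helpers imported in the diff discussed below, while maybe_load_state_dict and its torch.load body are hypothetical stand-ins for the real loading path.

import torch

from transformers.integrations import is_fsdp_enabled
from transformers.modeling_utils import is_local_dist_rank_0


def maybe_load_state_dict(checkpoint_file: str, hf_quantizer=None) -> dict:
    """Illustrative sketch: only local rank 0 materializes the full state dict.

    With FSDP_CPU_RAM_EFFICIENT_LOADING, non-rank-0 processes keep their
    parameters empty/on meta and receive the real values when FSDP shards
    and broadcasts them, so reading the checkpoint on every rank only burns
    CPU RAM and time.
    """
    if is_fsdp_enabled() and not is_local_dist_rank_0() and hf_quantizer is None:
        # Non-rank-0: skip loading entirely; an empty state dict keeps the
        # downstream parameter-loading loop a no-op.
        return {}
    # Hypothetical stand-in for the real checkpoint-reading logic.
    return torch.load(checkpoint_file, map_location="cpu")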

Comment thread: src/transformers/core_model_loading.py (Outdated)
Comment on lines +1109 to +1114
from .integrations import is_fsdp_enabled
from .modeling_utils import is_local_dist_rank_0

if is_fsdp_enabled() and not is_local_dist_rank_0() and hf_quantizer is None:
state_dict = []

Collaborator


I don't think this is where the skip should happen; you are creating a thread pool for nothing.

@@ -476,6 +476,10 @@ def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor
parent, param_type = get_module_from_name(model, param_name)
if param_type in parent._parameters and not isinstance(tensor, nn.Parameter):
Collaborator


This is only used in _move_missing_keys_from_meta_to_device for missing keys; it does not really make sense to me. Can you elaborate on why?
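For context, a rough sketch of the flag-setting the PR description refers to, not the exact patch: _set_loaded_parameter below is a hypothetical helper, while _is_hf_initialized is the attribute transformers uses to mark weights that were filled from the checkpoint so they are not randomly re-initialized later.

import torch
from torch import nn


def _set_loaded_parameter(parent: nn.Module, param_type: str, tensor: torch.Tensor) -> None:
    # Wrap the loaded tensor as a Parameter and attach it to the module.
    param = nn.Parameter(tensor, requires_grad=tensor.is_floating_point())
    # Mark the weight as already initialized so a later "initialize missing
    # keys" pass does not re-randomize it on non-rank-0 processes.
    param._is_hf_initialized = True
    setattr(parent, param_type, param)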

@winglian
Collaborator

winglian commented Feb 6, 2026

I don't think this behavior should happen at this low a level. There is a whole host of other cases where this would break ND-parallel loading (HSDP, for example), where you would have FSDP but multiple data-parallel meshes, each of which would also need the model weights.

- Move the FSDP check to the function entry point to avoid creating an empty thread pool
- Add detailed documentation for the _is_hf_initialized flag usage
- Add a should_skip_non_rank0_weight_loading() helper to support HYBRID_SHARD strategies (a hedged sketch follows below)
- Ensure compatibility with ND-parallel scenarios such as HSDP

This addresses the review feedback:
- @ArthurZucker: fix the thread-pool creation issue and explain the purpose of _is_hf_initialized
- @winglian: ensure compatibility with the HYBRID_SHARD and HYBRID_SHARD_ZERO2 strategies
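A hedged sketch of what such a helper could look like. The name should_skip_non_rank0_weight_loading comes from the commit message above, but the body is an assumption: it reads accelerate-style FSDP environment variables (the exact variable names are assumed) and never skips loading under hybrid sharding, since each replica group still needs the full weights.

import os


def _env_flag(name: str) -> bool:
    # Treat "1"/"true"/"yes" (case-insensitive) as enabled.
    return os.environ.get(name, "").strip().lower() in ("1", "true", "yes")


def should_skip_non_rank0_weight_loading(local_rank: int) -> bool:
    """Sketch of the helper named in the commit message; the body is assumed.

    Only skip checkpoint loading when FSDP CPU-RAM-efficient loading is on,
    this is not local rank 0, and the sharding strategy is not a hybrid one
    (with HYBRID_SHARD / HYBRID_SHARD_ZERO2, every replica group still needs
    the weights, so skipping would break ND-parallel setups such as HSDP).
    """
    if not (_env_flag("ACCELERATE_USE_FSDP") and _env_flag("FSDP_CPU_RAM_EFFICIENT_LOADING")):
        return False
    # Assumed environment variable for the configured sharding strategy.
    strategy = os.environ.get("FSDP_SHARDING_STRATEGY", "FULL_SHARD").upper()
    if "HYBRID_SHARD" in strategy:
        return False
    return local_rank != 0
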
@Cyrilvallez
Member

Agreed, everything should simply be skipped, both in _initialize_missing_keys and when loading the state dict.



Development

Successfully merging this pull request may close these issues.

FSDP_CPU_RAM_EFFICIENT_LOADING broken
