
Refactor core_model_loading to support FSDP shard-on-read loading #44974

Open

3outeille wants to merge 3 commits into fsdp-vs-ddp from fsdp-core-model-loading

Conversation

@3outeille (Member) commented Mar 24, 2026

TODO:

  • Saving seems to take a bit of time, though. Needs investigation.
  • Need to check that it works in 1D (FSDP or TP) and 2D (FSDP + TP). (A generic 2D setup sketch follows the log below.)

Running the script from #44996:

```
(env_pr-44974-fsdp-core-model-loading) ➜  pr-44974-fsdp-core-model-loading git:(pr-44974-fsdp-core-model-loading) ✗ torchrun --nproc_per_node=4 train_fsdp_tp.py 2>&1 | tee ref.txt
W0326 17:05:52.336000 1498148 torch/distributed/run.py:803] 
W0326 17:05:52.336000 1498148 torch/distributed/run.py:803] *****************************************
W0326 17:05:52.336000 1498148 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0326 17:05:52.336000 1498148 torch/distributed/run.py:803] *****************************************
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|██████████| 146/146 [00:00<00:00, 1015.33it/s]
Loading weights: 100%|██████████| 146/146 [00:00<00:00, 947.54it/s]
Loading weights: 100%|██████████| 146/146 [00:00<00:00, 888.82it/s]
Loading weights: 100%|██████████| 146/146 [00:00<00:00, 967.20it/s]
Step    0 | Loss: 12.9297
Step   10 | Loss: 6.8154
Step   20 | Loss: 6.2856
Step   30 | Loss: 6.5783
Step   40 | Loss: 6.1821
```
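For context, a 2D (FSDP + TP) setup of the kind the TODO refers to typically looks like the sketch below. This is a generic illustration built on public PyTorch DTensor APIs, not the actual train_fsdp_tp.py from #44996; the model and mesh sizes are made up.

```python
# Generic 2D (FSDP + TP) sketch -- illustrative only, not the script from
# #44996. Run under `torchrun --nproc_per_node=4 ...`.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # public path in torch >= 2.6
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 4 ranks arranged as a 2x2 mesh: outer dim for FSDP, inner dim for TP.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("fsdp", "tp"))

# Hypothetical two-layer MLP standing in for the real model.
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024)).cuda()

# Tensor-parallelize along the "tp" mesh dimension...
parallelize_module(model, mesh["tp"], {"0": ColwiseParallel(), "1": RowwiseParallel()})

# ...then shard the resulting DTensor parameters along "fsdp" (FSDP2 API).
fully_shard(model, mesh=mesh["fsdp"])
```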

@3outeille force-pushed the fsdp-core-model-loading branch from e5fc7eb to 4b2a921 on March 24, 2026 16:14
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

```python
    return alternation, src_group_to_glob, tgt_group_to_glob


def resolve_target_wildcards(source_pattern: str, target_pattern: str, source_key: str) -> str:
```
@3outeille (Member, Author) commented:
review this part
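For reviewers skimming: based only on the signature above, wildcard resolution presumably captures the "*" fragments that source_pattern matches in source_key and substitutes them into target_pattern. A hypothetical sketch of that behavior (not the code under review):

```python
# Hypothetical sketch of source -> target wildcard resolution; NOT the actual
# implementation under review, just the behavior the signature suggests.
import re

def resolve_target_wildcards(source_pattern: str, target_pattern: str, source_key: str) -> str:
    # Turn "model.layers.*.q_proj.weight" into a regex capturing each "*".
    regex = re.escape(source_pattern).replace(r"\*", "(.+?)") + "$"
    match = re.match(regex, source_key)
    if match is None:
        raise ValueError(f"{source_key!r} does not match {source_pattern!r}")
    # Substitute the captured fragments into the target's "*" slots in order.
    resolved = target_pattern
    for fragment in match.groups():
        resolved = resolved.replace("*", fragment, 1)
    return resolved

# resolve_target_wildcards("model.layers.*.self_attn.q_proj.weight",
#                          "decoder.blocks.*.attention.q.weight",
#                          "model.layers.3.self_attn.q_proj.weight")
# -> "decoder.blocks.3.attention.q.weight"
```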

```python
    model=model,
    missing_keys=loading_info.missing_keys if loading_info else None,
)
if len(collected_tensors) > 1 and model is not None:
```
@3outeille (Member, Author) commented:
review this part
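A toy illustration of the multi-tensor case this branch guards: when one target weight was stored as several checkpoint shards, the collected pieces have to be stitched back together. The shapes and the concat dim below are made-up assumptions, not taken from the PR:

```python
# Toy example of why len(collected_tensors) > 1 can occur: one logical weight
# split across checkpoint shards, re-joined along its shard dimension.
# Shapes and the concat dim are illustrative assumptions.
import torch

collected_tensors = [torch.randn(2, 8), torch.randn(2, 8)]  # two row shards
if len(collected_tensors) > 1:
    full_weight = torch.cat(collected_tensors, dim=0)  # (4, 8) global tensor
else:
    full_weight = collected_tensors[0]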

```python
# ref.shape is the DTensor global shape. For DTensor-based TP+FSDP,
# parallelize_module + fully_shard compose correctly, so ref.shape
# and ref.placements are already correct for the 2D DTensor.
fsdp_param = DTensor.from_local(
```
@3outeille (Member, Author) commented:
review this part
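For reference, DTensor.from_local composes a local shard with a device mesh and placements. Below is a standalone sketch of the pattern the excerpt uses, with made-up shapes and placements (public torch.distributed.tensor API from torch >= 2.5; run under torchrun with 4 ranks):

```python
# Standalone sketch of building a 2D DTensor from a locally materialized
# shard. `local_shard` and the placements are illustrative; in the PR they
# come from the reference parameter (`ref.shape` / `ref.placements`).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard  # public in torch >= 2.5

mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("fsdp", "tp"))

local_shard = torch.randn(64, 128, device="cuda")  # this rank's slice
fsdp_param = DTensor.from_local(
    local_shard,
    device_mesh=mesh,
    # e.g. FSDP shards dim 0, TP shards dim 1 of the global weight.
    placements=(Shard(0), Shard(1)),
)
# fsdp_param.shape is the global shape: (128, 256) on this 2x2 mesh.
```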

Commits:

- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests

@3outeille force-pushed the fsdp-core-model-loading branch from c567240 to c1dab9e on April 14, 2026 13:44
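The "range-math shard-on-read" in the first commit presumably means computing the index range a rank owns and reading only that slice from the checkpoint, instead of materializing the full tensor and slicing it afterwards. A minimal sketch of such range math (balanced, np.array_split-style chunking; the PR's DtensorShardOperation may use different semantics):

```python
# Minimal range-math sketch: which rows of a dim-0-sharded weight does this
# rank own? Balanced chunking with the remainder on the first ranks; the
# actual DtensorShardOperation semantics may differ.
def shard_range(global_len: int, world_size: int, rank: int) -> tuple[int, int]:
    base, rem = divmod(global_len, world_size)
    start = rank * base + min(rank, rem)
    stop = start + base + (1 if rank < rem else 0)
    return start, stop

start, stop = shard_range(4096, world_size=4, rank=1)  # -> (1024, 2048)
# Shard-on-read then loads only that slice, e.g. with safetensors:
#   with safe_open(ckpt, framework="pt") as f:
#       local = f.get_slice(key)[start:stop]
```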
@github-actions (Contributor) commented:

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44974&sha=21f056
