Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
SunMarc left a comment
Thanks, left a couple of comments!
```
@@ -0,0 +1,861 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
```
Let's create a new folder called tests/training and put those there instead. It will be better, I think.
```bash
#!/bin/bash

# Script to run all FSDP mixin tests for dense models in parallel.
# Works in tandem with a special test_fsdp_mixin.py that batches all 11 distributed tests in a single mp.spawn. (will not be committed)
# Uses concurrency-limited dispatch: multiple models share GPU pairs since test models are tiny (~7 MiB).
```
Even the fsdp folder, we can move that there.
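The script header above only hints at the dispatch scheme, so here is a rough Python rendering of the concurrency-limited idea; the model list, pytest paths, and concurrency limit are placeholders, not the PR's actual script.

```python
# A hypothetical sketch of concurrency-limited dispatch: several tiny-model
# test jobs share a bounded worker pool (and hence the same GPU pair).
import subprocess
from concurrent.futures import ThreadPoolExecutor

MODELS = ["gpt2", "llama", "qwen3"]  # placeholder test targets
MAX_CONCURRENT = 4                   # jobs allowed to share the GPU pair at once

def run_one(model: str) -> int:
    # Placeholder pytest invocation per model; paths differ in the real suite.
    return subprocess.call(["pytest", f"tests/models/{model}", "-k", "fsdp"])

with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as pool:
    exit_codes = list(pool.map(run_one, MODELS))
print(exit_codes)
```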
```python
fsdp_plan:
    Explicit FSDP config dict with a required "mode" key.

    Auto mode:
        fsdp_plan = {"mode": "auto"}

    Auto mode with optional policies:
        fsdp_plan = {"mode": "auto", "cpu_offload": False, "mixed_precision": True}

    Manual mode:
        fsdp_plan = {
            "mode": "manual",
            "modules": {
                "model.embed_tokens": ["free_full_weight"],
                "model.layers.0.self_attn": ["free_full_weight", "cpu_offload", "mixed_precision"],
                "model.layers.0.mlp": ["free_full_weight"],
                "model.norm": ["keep_full_weight"],
                "lm_head": ["keep_full_weight"],
            },
        }
"""
```
Maybe it could make sense to have a nice dataclass for fsdp_plan instead of a dict?
Yep, I think the idea was to stay simple like the pp / tp plan. But for fsdp we might want more control.
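For illustration, a minimal sketch of what such a dataclass could look like; the class and field names are assumptions mirroring the dict keys in the docstring above, not anything from the PR.

```python
# Hypothetical FsdpPlan dataclass; field names mirror the docstring's dict keys.
from dataclasses import dataclass, field

@dataclass
class FsdpPlan:
    mode: str = "auto"  # "auto" or "manual"
    cpu_offload: bool = False
    mixed_precision: bool = False
    # Manual mode only: module name/pattern -> list of per-module options.
    modules: dict[str, list[str]] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Convert back to the plain-dict format the plan consumers expect."""
        if self.mode == "manual":
            return {"mode": "manual", "modules": self.modules}
        return {"mode": "auto", "cpu_offload": self.cpu_offload, "mixed_precision": self.mixed_precision}
```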
```python
def apply_fsdp2(
    model,
    device_mesh,
    fsdp_plan: dict[str, Any] | None,
):
    """
```
It would be nice to check how it integrates with Trainer. We can pass fsdp and fsdp_config in training_args, and we would have to do the mapping to create the correct fsdp_plan; we'd also need to call apply_fsdp2 instead of prepare() on the model.
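To make the suggestion concrete, a hypothetical sketch of that mapping; the helper name and the fsdp_config keys it reads are assumptions, not the Trainer's actual schema.

```python
# Hypothetical mapping from Trainer-style fsdp_config to this PR's fsdp_plan.
# The keys read here ("offload_params", "mixed_precision") are assumed, not
# taken from the real TrainingArguments schema.
def fsdp_config_to_plan(fsdp_config: dict) -> dict:
    plan = {"mode": "auto"}
    if fsdp_config.get("offload_params"):
        plan["cpu_offload"] = True
    if fsdp_config.get("mixed_precision"):
        plan["mixed_precision"] = True
    return plan

# In the Trainer, one would then call apply_fsdp2(model, device_mesh,
# fsdp_config_to_plan(args.fsdp_config)) instead of accelerator.prepare(model).
```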
```python
from packaging import version

# TODO(3outeille): guard against missing import
from torch.distributed.device_mesh import init_device_mesh
```
Unconditional torch imports break non-torch environments
High Severity
Top-level import torch, import torch.distributed, import torch.multiprocessing, and from torch.distributed.device_mesh import init_device_mesh are added unconditionally at module level. This will cause an ImportError for anyone importing from testing_utils in an environment without PyTorch (or with an older PyTorch lacking device_mesh). The existing file already uses is_torch_available() guards for other torch imports — these new ones need the same treatment.
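A minimal sketch of the guarded-import pattern the comment asks for, mirroring the is_torch_available() gating that testing_utils already uses for its other torch imports:

```python
# Guarded imports: only pull in torch (and device_mesh, which requires a
# recent PyTorch) when torch is actually installed.
from transformers.utils import is_torch_available

if is_torch_available():
    import torch
    import torch.distributed
    import torch.multiprocessing

    try:
        from torch.distributed.device_mesh import init_device_mesh
    except ImportError:  # older PyTorch without device_mesh
        init_device_mesh = None
```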
```python
)
from .test_pipeline_mixin import PipelineTesterMixin
from .test_tensor_parallel_mixin import TensorParallelTesterMixin
from .test_fsdp_mixin import FSDPTesterMixin
```
Duplicate FSDPTesterMixin import in causal_lm_tester
Low Severity
from .test_fsdp_mixin import FSDPTesterMixin appears twice — once at line 32 and again at line 43. One of these is redundant and was likely left in by mistake during development.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
ArthurZucker left a comment
Will look at the tests next time, but my main comment:
You are not integrated with the core-model-loading API, no?
My understanding is that you should shard the weights exactly the same way we do for TP, maybe push further since you want to shard ALL layers; then when running the forward, each process materializes the weights locally (GPU 0-8 end up with the same full tensor) and discards it, keeping only its slice.
We need 2 points:
- TENSOR_PARALLEL_LAYERS integration in core_model_loading needs to support fsdp. This is what's going to be responsible for loading the weights.
- distribute_module, which needs to happen before the load, and is responsible for attaching the appropriate hooks for fsdp2.
This also needs to be explained like the above comment: what is fsdp? -> 1. sharding plan, 2. hook plan.
Now if you have to rely on DTensor, you might need to change set_param? Or you just apply the DTensor conversion after all the loading.
The most important is to test, say, Mixtral with the dynamic weight loader. The way I see it, you'll load all weights on all layers, then shard (discard) some of them.
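A minimal sketch of the "materialize the full tensor, keep only a slice" idea using public DTensor APIs (assuming torch >= 2.5 and a launch under torchrun); this illustrates the concept, not the PR's actual loading path.

```python
# Each rank materializes the same full tensor, then keeps only its dim-0 shard.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("gloo")  # run under torchrun so RANK/WORLD_SIZE are set
mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("dp_shard",))

full = torch.randn(1024, 1024)                       # full tensor on every rank
sharded = distribute_tensor(full, mesh, [Shard(0)])  # DTensor holding this rank's slice
del full                                             # the full copy can now be discarded
print(sharded.to_local().shape)                      # (1024 // world_size, 1024)
```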
```python
if not dist.is_initialized():
    try:
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
        backend = backend_map.get(device_type)
        if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", "0")):
            backend = "ccl"
        if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
            backend = "ccl"

        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
        if device_type != "cpu":
            current_device.set_device(local_rank)

    except Exception as e:
        raise OSError(
            "We tried to initialize torch.distributed for you, but it failed. Make "
            "sure you init torch distributed in your script to use `fsdp_plan`."
        ) from e

if device_type != "cpu":
    current_device.set_device(int(os.environ["LOCAL_RANK"]))
    index = current_device.current_device()
    fsdp_device = torch.device(device_type, index)
    device_map = fsdp_device
else:
    fsdp_device = torch.device(device_type)
    device_map = device_type or {}

fsdp_size = dist.get_world_size()
device_mesh = torch.distributed.init_device_mesh(fsdp_device.type, (fsdp_size,), mesh_dim_names=("dp_shard",))
```
Can't we re-use the func we defined in tensor_parallel?
| """ | ||
| Identifies transformer block classes in a model for FSDP wrapping. | ||
| These are typically the repeated layers that benefit from FSDP sharding. | ||
|
|
||
| Returns a set of module classes that should be wrapped with fully_shard(). | ||
| """ |
pretty sure you can just check if the layer is GradientCheckpointingLayer 😉
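A minimal sketch of that suggestion, assuming the model's decoder blocks follow the modern convention of subclassing GradientCheckpointingLayer:

```python
# Collect block classes by type check instead of name-pattern matching.
from transformers.modeling_layers import GradientCheckpointingLayer

block_classes = {
    type(module)
    for module in model.modules()
    if isinstance(module, GradientCheckpointingLayer)
}
```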
```python
        logger.debug(f"Applied fully_shard to {name} ({type(module).__name__})")


def _find_final_norm(model, decoder_layer_names):
```
```python
base_model_pp_plan = {
    "embed_tokens": (["input_ids"], ["inputs_embeds"]),
    "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
    "norm": (["hidden_states"], ["hidden_states"]),
}
```
This looks like something we can define in metadata / take from PP, no?
```python
# Untied: [final_norm, lm_head]
# Tied: [final_norm, embed_tokens] - embed_tokens.weight IS lm_head.weight.
```
Can be taken from:
```python
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
```
No?
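Putting the two suggestions together, a hypothetical sketch of deriving both the final norm and the output head from PP-plan metadata; the selection logic is an assumption, and only the plan dicts come from the comments above.

```python
# Plan dicts copied from the review comments above.
base_model_pp_plan = {
    "embed_tokens": (["input_ids"], ["inputs_embeds"]),
    "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
    "norm": (["hidden_states"], ["hidden_states"]),
}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

# Final norm: the only stage whose inputs and outputs are both just hidden_states.
final_norm_name = next(
    name
    for name, (inputs, outputs) in base_model_pp_plan.items()
    if inputs == ["hidden_states"] and outputs == ["hidden_states"]
)
# Output head: the stage that produces logits.
head_name = next(name for name, (_, outputs) in _pp_plan.items() if "logits" in outputs)
print(final_norm_name, head_name)  # norm lm_head
```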
```python
    return strategy != "keep_full_weight", mp_policy, offload_policy


def _iter_manual_plan_targets(model, pattern, name_to_module, already_sharded_names):
```
Can you document what this does? Not entirely sure we need it, but if we do, it's probably something the TP plan, PP plan, or dtype plan are going to be using?
Based on my CP experiments:
```python
# _parse_fsdp_plan_mode could normalize strings:
if isinstance(fsdp_plan, str):
    fsdp_plan = {"mode": fsdp_plan}
```
Users need to do this:
```python
world_mesh = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp"))
fsdp_mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
model = AutoModelForCausalLM.from_pretrained(..., fsdp_device_mesh=fsdp_mesh, fsdp_plan={"mode": "auto"})
```
The PR's docstring or the docs should mention this.
* feat: from_pretrained distributed refactor (FSDP2 + TP via DistributedConfig)
  - Expand DistributedConfig with tp_size, tp_plan, fsdp_size, fsdp_plan
  - Add init_device_mesh() for building 2D DeviceMesh from DistributedConfig
  - Reuse apply_fsdp2() from PR #44083 for FSDP2 fully_shard wrapping
  - Rewire from_pretrained with two clean separated paths:
    1. distributed_config → native torch.distributed (no accelerate)
    2. Everything else → accelerate (unchanged)
  - Export DistributedConfig from top-level transformers package
  - Add unit tests for DistributedConfig
* Convert DistributedConfig to dict for JSON serialization
* some fixes
* linting
* linting
* freaking linting again
* some fixes for CI
* linting
* fix tests
* linting
* fix tp tests
- Add apply_fully_shard_data_parallel() with auto/manual mode block detection
- FSDP vs DDP loss/grad parity tests
- Distributed test helpers (testing_utils.py)
- is_fsdp_enabled(), is_fsdp_managed_module() utilities
- Minimal FSDP hooks in from_pretrained
- FSDP-aware flash attention check
- train_fsdp_tp.py: minimal FSDP+TP training example
- train_fsdp_tp_torchtitan_style.py: torchtitan-style training example
- verify_loading.py: save/load roundtrip verification
- run_compare.sh: FSDP+TP vs FSDP-only comparison
- run_verify_all.sh: run verification across all modes
- tmp_generate.py: quick generation test
…sformers into distributed_api
[For maintainers] Suggested jobs to run (before merge): run-slow: clap, deit
- Re-export is_fsdp_enabled and is_fsdp_managed_module from integrations/fsdp.py (moved to distributed/utils.py)
- Remove unused # type: ignore comments in generation/utils.py
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44083&sha=37dcc1


This PR introduces first-class FSDP2 (Fully Sharded Data Parallel v2) support directly in Transformers, bypassing the need for Accelerate's FSDP wrapper. It covers the full lifecycle: model distribution, training, checkpointing, and CI testing across dozens of models.
A standalone script for usage and expected throughput will be available at https://github.com/huggingface/distributed-training-cookbook
Correctness has been tested to match the Torchtitan implementation (cf. https://github.com/huggingface/torchtitan/blob/sanity-check-fsdp/torchtitan/experiments/debug_fsdp/README.md)
NOTE: Transformers modeling takes more memory for some reason; to investigate later.
1. Native FSDP2 Integration (`src/transformers/integrations/fsdp.py`)

The core addition is a new FSDP integration module that provides:

- `initialize_fsdp()` -- Sets up the `DeviceMesh` and process group for FSDP2 (requires PyTorch >= 2.5). Handles automatic backend detection (NCCL, GLOO, XCCL, etc.) and device assignment.
- `apply_fsdp2()` with two modes:
  - Auto mode (`{"mode": "auto"}`) -- Automatically discovers transformer block classes (DecoderLayer, EncoderLayer, etc.), shards input embeddings and all transformer blocks, and groups the final norm + output head together. Supports optional `cpu_offload` and `mixed_precision` policies.
  - Manual mode (`{"mode": "manual", "modules": {...}}`) -- Lets users specify exactly which modules to shard, with per-module options like `"free_full_weight"`, `"keep_full_weight"`, `"cpu_offload"`, and `"mixed_precision"`.
- Smart block detection (`get_transformer_block_classes()`) -- Finds transformer block classes by name pattern and filters out nested blocks (e.g., MoeBlock inside DecoderLayer) to only FSDP-wrap the outermost ones. This enables MoE model support.
- Tied weight handling -- Properly handles weight tying (e.g., `lm_head.weight` == `embed_tokens.weight`) by grouping tied modules and re-tying after `fully_shard` replaces parameters with DTensors.

2. FSDP2-Aware Save/Load via DCP + Safetensors

- `save_fsdp_model()` -- Saves FSDP2 model weights using PyTorch's Distributed Checkpoint (DCP) with `HuggingFaceStorageWriter`, enabling parallel distributed saves with automatic consolidation into standard HF-compatible safetensors files.
- `save_pretrained()` integration -- `PreTrainedModel.save_pretrained()` now detects FSDP2 models (`_is_fsdp_managed_module`) and automatically routes to the DCP save path.
- `from_pretrained()` integration -- Accepts new `fsdp_plan` and `fsdp_device_mesh` kwargs. After loading weights, it applies FSDP2 distribution via `distribute_fsdp_model()`.

3. Comprehensive FSDP Test Suite (`tests/test_fsdp_mixin.py`)

A new `FSDPTesterMixin` class is added to the standard test infrastructure, automatically inherited by all `CausalLMModelTest` classes. It includes 7 batched subtests per model, all run on CPU through the `gloo` backend + `mp.spawn`:

- `sharding_structure_untied` / `sharding_structure_tied`
- `auto_plan_vs_ddp` (untied/tied)
- `manual_plan_vs_ddp` (untied/tied)
- `save_load` -- `save_pretrained` + reload via `from_pretrained` produces bit-exact weights

The tests also validate checkpoint resumability: train for N/2 steps, save a checkpoint (model via DCP+safetensors, optimizer+RNG via distcp), load into a fresh model, continue training, and verify the full trace matches an uninterrupted DDP run.
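For orientation, a minimal sketch (a hypothetical harness, not the PR's test code) of the CPU + gloo + mp.spawn pattern the mixin relies on:

```python
# Spawn N processes, each joining a gloo process group, then run subtests.
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # ... build a tiny model here, run the FSDP-vs-DDP subtests, assert parity ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True)
```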
4. CI Test for Broad Model Coverage
Two bash scripts run the FSDP mixin tests across many models in parallel:
- Dense models: Tests ~10 active models (GPT-2, Qwen3, Phi, Llama, ModernBERT-decoder, OLMo3, Phi3, Mistral, LFM2, Qwen3.5) out of 40 total, ranked by HuggingFace Hub downloads.
- MoE models: Tests ~10 active MoE models (GPT-OSS, GLM-MoE-DSA, Qwen3-MoE, GLM4-MoE-Lite, Qwen3.5-MoE, DeepSeek-V2, Qwen3-Next, Mixtral, Qwen2-MoE, PhiMoE) out of 24 total.