
FSDP2 native support in transformers #44083

Open
3outeille wants to merge 15 commits into main from fsdp-vs-ddp

Conversation

Member

@3outeille 3outeille commented Feb 17, 2026

  • TODO:
  • fsdp => do it like TP: a manual fsdp_plan mode that becomes the auto mode by default

This PR introduces first-class FSDP2 (Fully Sharded Data Parallel v2) support directly in Transformers, bypassing the need for Accelerate's FSDP wrapper. It covers the full lifecycle: model distribution, training, checkpointing, and CI testing across dozens of models.

A standalone script for usage and expected throughput will be available at https://github.com/huggingface/distributed-training-cookbook

TODO: add example train_fsdp.py in the PR

Correctness has been verified against the Torchtitan implementation (cf. https://github.com/huggingface/torchtitan/blob/sanity-check-fsdp/torchtitan/experiments/debug_fsdp/README.md)

(image: correctness comparison with Torchtitan)

NOTE: the Transformers modeling code uses more memory for reasons that are not yet understood. To be investigated later.

1. Native FSDP2 Integration (src/transformers/integrations/fsdp.py)

The core addition is a new FSDP integration module that provides:

  • initialize_fsdp() -- Sets up the DeviceMesh and process group for FSDP2 (requires PyTorch >= 2.5). Handles automatic backend detection (NCCL, GLOO, XCCL, etc.) and device assignment.

  • apply_fsdp2() with two modes (see the usage sketch after this list):

    • Auto mode ({"mode": "auto"}) -- Automatically discovers transformer block classes (DecoderLayer, EncoderLayer, etc.), shards input embeddings, all transformer blocks, and groups the final norm + output head together. Supports optional cpu_offload and mixed_precision policies.
    • Manual mode ({"mode": "manual", "modules": {...}}) -- Lets users specify exactly which modules to shard, with per-module options like "free_full_weight", "keep_full_weight", "cpu_offload", and "mixed_precision".
  • Smart block detection (get_transformer_block_classes()) -- Finds transformer block classes by name pattern and filters out nested blocks (e.g., MoeBlock inside DecoderLayer) to only FSDP-wrap the outermost ones. This enables MoE model support.

  • Tied weight handling -- Properly handles weight tying (e.g., lm_head.weight == embed_tokens.weight) by grouping tied modules and re-tying after fully_shard replaces parameters with DTensors.
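As a rough end-to-end illustration of this low-level path, here is a minimal sketch; the assumption that initialize_fsdp() returns the DeviceMesh is mine (the PR text only says it sets up the mesh and process group), and the checkpoint name is a placeholder:

```python
from transformers import AutoModelForCausalLM
from transformers.integrations.fsdp import initialize_fsdp, apply_fsdp2

# Assumption: initialize_fsdp() sets up the process group and hands back the DeviceMesh.
device_mesh = initialize_fsdp()

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # placeholder checkpoint

# Auto mode: shards the embeddings, every transformer block, and the final norm + head.
apply_fsdp2(model, device_mesh, {"mode": "auto"})
```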

2. FSDP2-Aware Save/Load via DCP + Safetensors

  • save_fsdp_model() -- Saves FSDP2 model weights using PyTorch's Distributed Checkpoint (DCP) with HuggingFaceStorageWriter, enabling parallel distributed save with automatic consolidation into standard HF-compatible safetensors files.

  • save_pretrained() integration -- PreTrainedModel.save_pretrained() now detects FSDP2 models (_is_fsdp_managed_module) and automatically routes to the DCP save path.

  • from_pretrained() integration -- Accepts new fsdp_plan and fsdp_device_mesh kwargs. After loading weights, it applies FSDP2 distribution via distribute_fsdp_model().
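Putting these pieces together, a hedged sketch of the save/load roundtrip from user code could look like the following (the output directory is a placeholder; the script is assumed to be launched with torchrun so the process group exists):

```python
from transformers import AutoModelForCausalLM

# Load directly into an FSDP2-sharded state using the new kwargs from this PR.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",          # placeholder checkpoint
    fsdp_plan={"mode": "auto"},
)

# save_pretrained() detects the FSDP2-managed modules and routes through the
# DCP + HuggingFaceStorageWriter path, producing standard HF safetensors files.
model.save_pretrained("./fsdp2-checkpoint")

# Reloading re-applies the FSDP2 plan on top of the consolidated weights.
reloaded = AutoModelForCausalLM.from_pretrained("./fsdp2-checkpoint", fsdp_plan={"mode": "auto"})
```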

3. Comprehensive FSDP Test Suite (tests/test_fsdp_mixin.py)

A new FSDPTesterMixin class is added to the standard test infrastructure and is automatically inherited by all CausalLMModelTest classes. It includes 7 batched subtests per model, all run on CPU with the gloo backend via mp.spawn:

| Test | What it validates |
|------|-------------------|
| sharding_structure (untied/tied) | FSDP wrapping targets match expectations |
| auto_plan_vs_ddp (untied/tied) | FSDP2 auto mode produces losses, grad norms, and final weights identical to DDP |
| manual_plan_vs_ddp (untied/tied) | FSDP2 manual mode matches DDP |
| save_load | Save via save_pretrained + reload via from_pretrained produces bit-exact weights |

The tests also validate checkpoint resumability: train for N/2 steps, save checkpoint (model via DCP+safetensors, optimizer+RNG via distcp), load into a fresh model, continue training, and verify the full trace matches an uninterrupted DDP run.
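For orientation, a minimal sketch of that resume flow as seen from user code; the checkpoint layout and the use of plain optimizer.state_dict() are my assumptions, not the test's exact implementation:

```python
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM

ckpt_dir = "./resume_ckpt"  # placeholder path

# --- after N/2 training steps ---
model.save_pretrained(ckpt_dir)  # model weights: DCP save, consolidated safetensors
dcp.save({"optimizer": optimizer.state_dict()}, checkpoint_id=f"{ckpt_dir}/optim")  # distcp

# --- fresh process: resume ---
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, fsdp_plan={"mode": "auto"})
optimizer = torch.optim.AdamW(model.parameters())
state = {"optimizer": optimizer.state_dict()}
dcp.load(state, checkpoint_id=f"{ckpt_dir}/optim")
optimizer.load_state_dict(state["optimizer"])
# ...continue training and compare the loss trace against an uninterrupted DDP run.
```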

4. CI Test for Broad Model Coverage

Two bash scripts run the FSDP mixin tests across many models in parallel:

  • Dense models: Tests ~10 active models (GPT-2, Qwen3, Phi, Llama, ModernBERT-decoder, OLMo3, Phi3, Mistral, LFM2, Qwen3.5) out of 40 total, ranked by HuggingFace Hub downloads.

  • MoE models: Tests ~10 active MoE models (GPT-OSS, GLM-MoE-DSA, Qwen3-MoE, GLM4-MoE-Lite, Qwen3.5-MoE, DeepSeek-V2, Qwen3-Next, Mixtral, Qwen2-MoE, PhiMoE) out of 24 total.

@3outeille 3outeille changed the title from "Add distributed training CI job to CircleCI configuration" to "FSDP native support in transformers" Feb 17, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@3outeille 3outeille changed the base branch from v5-distributed-training-ci to main March 11, 2026 11:18
@3outeille 3outeille marked this pull request as ready for review March 11, 2026 13:51
@github-actions github-actions Bot requested review from SunMarc and ydshieh March 11, 2026 13:51
Member

@SunMarc SunMarc left a comment


Thanks, left a couple of comments!

Comment thread src/transformers/modeling_utils.py Outdated
Comment thread tests/test_fsdp_mixin.py
@@ -0,0 +1,861 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
Member


let's create a new folder called tests/training and put those there instead. It will be better I think

Comment on lines +1 to +5
#!/bin/bash

# Script to run all FSDP mixin tests for dense models in parallel.
# Work in tandem with a special test_fsdp_mixin.py that batches all 11 distributed tests in a single mp.spawn. (will not be committed)
# Uses concurrency-limited dispatch: multiple models share GPU pairs since test models are tiny (~7 MiB).
Member


even the fsdp folder we can move that there

Comment thread src/transformers/integrations/fsdp.py Outdated
Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +378 to +398
fsdp_plan:
    Explicit FSDP config dict with a required "mode" key.

    Auto mode:
        fsdp_plan = {"mode": "auto"}

    Auto mode with optional policies:
        fsdp_plan = {"mode": "auto", "cpu_offload": False, "mixed_precision": True}

    Manual mode:
        fsdp_plan = {
            "mode": "manual",
            "modules": {
                "model.embed_tokens": ["free_full_weight"],
                "model.layers.0.self_attn": ["free_full_weight", "cpu_offload", "mixed_precision"],
                "model.layers.0.mlp": ["free_full_weight"],
                "model.norm": ["keep_full_weight"],
                "lm_head": ["keep_full_weight"],
            },
        }
"""
Member


maybe it could make sense to have a nice dataclass for fsdp_plan instead of a dict ?
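For illustration, a rough sketch of what such a dataclass could look like; the field names only mirror the keys already used in the dict format above and are not a proposal of the final API:

```python
from dataclasses import dataclass, field


@dataclass
class FSDPPlan:
    # "auto" or "manual", mirroring the required "mode" key of the dict format.
    mode: str = "auto"
    # Manual mode only: module name/pattern -> per-module options such as
    # "free_full_weight", "keep_full_weight", "cpu_offload", "mixed_precision".
    modules: dict[str, list[str]] = field(default_factory=dict)
    # Auto-mode policies.
    cpu_offload: bool = False
    mixed_precision: bool = False
```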

Collaborator


Yep, I think the idea was to stay simple like the pp / tp plan. But for fsdp we might want more control.

Comment on lines +371 to +376
def apply_fsdp2(
    model,
    device_mesh,
    fsdp_plan: dict[str, Any] | None,
):
    """
Member


It would be nice to check how it integrates well with trainer. We can pass fsdp and fsdp_config in training_args and we would have to do the mapping to create the correct fsdp_plan + we need to call apply_fsdp2 instead of prepare() on the model.
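To make the idea concrete, a hedged sketch of such a mapping; the fsdp_config key names used here are assumptions, not the actual Trainer schema:

```python
# Illustrative only: translate Trainer-style fsdp_config options into the
# fsdp_plan dict consumed by apply_fsdp2.
def fsdp_config_to_plan(fsdp_config: dict) -> dict:
    plan = {"mode": "auto"}
    if fsdp_config.get("offload_params"):    # assumed key name
        plan["cpu_offload"] = True
    if fsdp_config.get("mixed_precision"):   # assumed key name
        plan["mixed_precision"] = True
    return plan
```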

Comment thread src/transformers/modeling_utils.py Outdated
Comment thread src/transformers/testing_utils.py Outdated
from packaging import version

#TODO(3outeille): guarding to protect against missing import
from torch.distributed.device_mesh import init_device_mesh


Unconditional torch imports break non-torch environments

High Severity

Top-level import torch, import torch.distributed, import torch.multiprocessing, and from torch.distributed.device_mesh import init_device_mesh are added unconditionally at module level. This will cause an ImportError for anyone importing from testing_utils in an environment without PyTorch (or with an older PyTorch lacking device_mesh). The existing file already uses is_torch_available() guards for other torch imports — these new ones need the same treatment.


Comment thread tests/causal_lm_tester.py Outdated
)
from .test_pipeline_mixin import PipelineTesterMixin
from .test_tensor_parallel_mixin import TensorParallelTesterMixin
from .test_fsdp_mixin import FSDPTesterMixin


Duplicate FSDPTesterMixin import in causal_lm_tester

Low Severity

from .test_fsdp_mixin import FSDPTesterMixin appears twice — once at line 32 and again at line 43. One of these is redundant and was likely left in by mistake during development.



@cursor cursor Bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.


Comment thread src/transformers/integrations/fsdp.py
@3outeille 3outeille changed the title from "FSDP native support in transformers" to "FSDP2 native support in transformers" Mar 18, 2026
@3outeille 3outeille requested review from ArthurZucker and SunMarc and removed request for ydshieh March 18, 2026 20:58
Collaborator

@ArthurZucker ArthurZucker left a comment


Will look at the tests next time but my main comment:

You are not integrated with the core-model-loading API, no?

My understanding is that you should shard the weights exactly the same way we do for TP, maybe pushed further since you want to shard ALL layers; then, when running the forward, each process materializes the weights locally (GPU 0-8 end up with the same full tensor) and discards them, keeping only its slice.

We need 2 points:

  1. TENSOR_PARALLEL_LAYERS integration in core_model_loading needs to support fsdp. This is what's gonna be responsible for loading the weights.
  2. distribute_module, which needs to happen before the load, and is responsible for attaching the appropriate hooks for fsdp2.
    This also needs to be explained like the above comment: what is fsdp? -> 1. sharding plan, 2. hook plan.

Now if you have to rely on DTensor, you might need to change set_param? Or you just apply the DTensor conversion after all the loading.

The most important thing is to test, say, Mixtral with the dynamic weight loader. The way I see it, you'll load all weights on all layers, then shard (discard) some of them.

Comment thread src/transformers/integrations/fsdp.py Outdated
Comment thread src/transformers/integrations/fsdp.py Outdated
Comment thread src/transformers/integrations/fsdp.py Outdated
Comment on lines +99 to +132
if not dist.is_initialized():
    try:
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
        backend = backend_map.get(device_type)
        if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", "0")):
            backend = "ccl"
        if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
            backend = "ccl"

        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
        if device_type != "cpu":
            current_device.set_device(local_rank)

    except Exception as e:
        raise OSError(
            "We tried to initialize torch.distributed for you, but it failed. Make "
            "sure you init torch distributed in your script to use `fsdp_plan`."
        ) from e

if device_type != "cpu":
    current_device.set_device(int(os.environ["LOCAL_RANK"]))
    index = current_device.current_device()
    fsdp_device = torch.device(device_type, index)
    device_map = fsdp_device
else:
    fsdp_device = torch.device(device_type)
    device_map = device_type or {}

fsdp_size = dist.get_world_size()
device_mesh = torch.distributed.init_device_mesh(fsdp_device.type, (fsdp_size,), mesh_dim_names=("dp_shard",))
Collaborator


can't we re-use the func we defined in tensor_parallel?

Comment on lines +149 to +154
"""
Identifies transformer block classes in a model for FSDP wrapping.
These are typically the repeated layers that benefit from FSDP sharding.

Returns a set of module classes that should be wrapped with fully_shard().
"""
Collaborator


pretty sure you can just check if the layer is GradientCheckpointingLayer 😉

class GradientCheckpointingLayer(nn.Module):
we make all blocks inherit from this!
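A minimal sketch of that check, assuming GradientCheckpointingLayer is importable from transformers.modeling_layers:

```python
from transformers.modeling_layers import GradientCheckpointingLayer

# Collect the repeated block types instead of matching class names by pattern.
block_classes = {
    type(module) for module in model.modules() if isinstance(module, GradientCheckpointingLayer)
}
```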

Comment thread src/transformers/integrations/fsdp.py Outdated
logger.debug(f"Applied fully_shard to {name} ({type(module).__name__})")


def _find_final_norm(model, decoder_layer_names):
Collaborator


    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

this looks like something we can define in metadata / take from PP no?

Comment on lines +266 to +267
# Untied: [final_norm, lm_head]
# Tied: [final_norm, embed_tokens] - embed_tokens.weight IS lm_head.weight.
Collaborator


can be taken from:

    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}


No?

return strategy != "keep_full_weight", mp_policy, offload_policy


def _iter_manual_plan_targets(model, pattern, name_to_module, already_sharded_names):
Collaborator


can you document what this does?
not entirely sure we need it, but if we do, it's probably something the TP plan, PP plan, or dtype plan are gonna be using?


@kashif
Contributor

kashif commented Mar 20, 2026

Based on my CP experiments:

  1. Accept fsdp_plan="auto" string shorthand: Currently fsdp_plan={"mode": "auto"} works but fsdp_plan="auto" raises a confusing error. Since tp_plan="auto" is accepted as a string, fsdp_plan should be consistent:
  # _parse_fsdp_plan_mode could normalize strings:
  if isinstance(fsdp_plan, str):
      fsdp_plan = {"mode": fsdp_plan}
  2. Document the CP + FSDP2 device mesh pattern: When using Context Parallelism with native FSDP2, the fsdp_device_mesh must be the flattened dp_cp mesh, not just the dp submesh; otherwise FSDP2 doesn't shard parameters across CP ranks and you OOM. This is non-obvious:
  # Users need to do this:
  world_mesh = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp"))
  fsdp_mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
  model = AutoModelForCausalLM.from_pretrained(..., fsdp_device_mesh=fsdp_mesh, fsdp_plan={"mode": "auto"})

The PR's docstring or the 3D_parallel.py example should show this pattern.

  3. Add a CP example using native FSDP2: The current 3D_parallel.py uses FSDP1 (FullyShardedDataParallel with NO_SHARD). An example showing native FSDP2 (fsdp_plan) + CP would be good to have.

  4. Note about the CP rotation method: The default CP rotation (allgather) requires allgather_into_tensor_coalesced, which may not be available in all PyTorch builds. The alltoall rotation works universally. Worth noting in the docs or defaulting to alltoall.

3outeille added a commit that referenced this pull request Mar 25, 2026
…dConfig)

- Expand DistributedConfig with tp_size, tp_plan, fsdp_size, fsdp_plan
- Add init_device_mesh() for building 2D DeviceMesh from DistributedConfig
- Reuse apply_fsdp2() from PR #44083 for FSDP2 fully_shard wrapping
- Rewire from_pretrained with two clean separated paths:
  1. distributed_config → native torch.distributed (no accelerate)
  2. Everything else → accelerate (unchanged)
- Export DistributedConfig from top-level transformers package
- Add unit tests for DistributedConfig
3outeille added a commit that referenced this pull request Mar 25, 2026
3outeille added a commit that referenced this pull request Mar 25, 2026
@3outeille 3outeille changed the base branch from main to distributed_api March 25, 2026 15:42
3outeille added a commit that referenced this pull request Mar 25, 2026
3outeille added a commit that referenced this pull request Mar 26, 2026
* feat: from_pretrained distributed refactor (FSDP2 + TP via DistributedConfig)

- Expand DistributedConfig with tp_size, tp_plan, fsdp_size, fsdp_plan
- Add init_device_mesh() for building 2D DeviceMesh from DistributedConfig
- Reuse apply_fsdp2() from PR #44083 for FSDP2 fully_shard wrapping
- Rewire from_pretrained with two clean separated paths:
  1. distributed_config → native torch.distributed (no accelerate)
  2. Everything else → accelerate (unchanged)
- Export DistributedConfig from top-level transformers package
- Add unit tests for DistributedConfig

* Convert DistributedConig to dict for JSON serialization

* some fixes

* linting

* linting

* freaking linting again

* some fixes for CI

* linting

* fix tests

* linting

* fix tp tests
@ArthurZucker ArthurZucker changed the base branch from distributed_api to main April 10, 2026 08:35
@3outeille 3outeille changed the base branch from main to distributed_api April 13, 2026 13:33
3outeille and others added 3 commits April 13, 2026 15:33
- Add apply_fully_shard_data_parallel() with auto/manual mode block detection
- FSDP vs DDP loss/grad parity tests
- Distributed test helpers (testing_utils.py)
- is_fsdp_enabled(), is_fsdp_managed_module() utilities
- Minimal FSDP hooks in from_pretrained
- FSDP-aware flash attention check
3outeille and others added 9 commits April 13, 2026 16:34
- train_fsdp_tp.py: minimal FSDP+TP training example
- train_fsdp_tp_torchtitan_style.py: torchtitan-style training example
- verify_loading.py: save/load roundtrip verification
- run_compare.sh: FSDP+TP vs FSDP-only comparison
- run_verify_all.sh: run verification across all modes
- tmp_generate.py: quick generation test
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: clap, deit

- Re-export is_fsdp_enabled and is_fsdp_managed_module from
  integrations/fsdp.py (moved to distributed/utils.py)
- Remove unused # type: ignore comments in generation/utils.py
@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44083&sha=37dcc1
