
[Trainer] Add ddp_static_graph option #45519

Merged
SunMarc merged 2 commits into huggingface:main from KeitaW:feat/ddp-static-graph
Apr 21, 2026

Conversation

Contributor

@KeitaW KeitaW commented Apr 20, 2026

What does this PR do?

Exposes PyTorch DDP's static_graph flag via a new ddp_static_graph: Optional[bool] field on TrainingArguments, forwarded through Trainer._build_accelerator_args into Accelerate's DistributedDataParallelKwargs (which already supports it; only the Transformers-side plumbing was missing).

This completes the set of DDP flags already partially exposed on TrainingArguments (ddp_find_unused_parameters, ddp_bucket_cap_mb, ddp_broadcast_buffers). Today a user can configure nearly everything about DDP except static_graph; the only workarounds are monkey-patching DistributedDataParallel.__init__ via a sitecustomize.py shim or subclassing Trainer to override _wrap_model, neither of which is portable.
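For context, a minimal sketch of the Accelerate side that already exists today; this PR only adds the Trainer-side plumbing that forwards the new TrainingArguments field into this handler:

```python
from accelerate import Accelerator, DistributedDataParallelKwargs

# Accelerate already accepts static_graph on its DDP kwargs handler.
handler = DistributedDataParallelKwargs(static_graph=True)
accelerator = Accelerator(kwargs_handlers=[handler])
```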

Fixes #45518

Why

Per PyTorch's DDP docs, static_graph=True relaxes several DDP reducer constraints for users who can guarantee a stable graph across iterations: "Reentrant backwards", "Activation checkpointing when model has unused parameters", "There are model parameters that are outside of forward function", and "Potentially improve performance when there are unused parameters."

A common HF Trainer scenario where this matters: a model with trainable parameters that don't contribute to loss on every iteration (e.g. frozen submodules, or multi-head models where only one head is trained). Under DDP with ddp_find_unused_parameters=False (the Trainer default), such a model fails at iter 1 with:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss.

The common workaround is ddp_find_unused_parameters=True, but that forces DDP to traverse the autograd graph on every iteration to find unused params — a measurable per-step cost. static_graph=True is the performance-optimal alternative (PyTorch's own note: "potentially improve performance when there are unused parameters"): DDP records the participating-parameter set on iter 1 and assumes it's stable thereafter.
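As a usage sketch once this PR lands (the field name follows this PR; output_dir and the rest of the training setup are placeholders):

```python
from transformers import TrainingArguments

# Workaround today: per-iteration autograd-graph traversal to find unused params.
args_workaround = TrainingArguments(output_dir="out", ddp_find_unused_parameters=True)

# With this PR: DDP records the participating-parameter set on the first iteration
# and assumes it stays stable afterwards.
args_static = TrainingArguments(output_dir="out", ddp_static_graph=True)
```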

The earlier blocker that once made static_graph=True unsafe with HF models (ModelOutput subclasses not registered as pytree nodes) was fixed in #25358 (closing #25357, merged 2023-08); the fix remains in src/transformers/utils/generic.py, where the repo itself documents static_graph=True safety with ModelOutput in a docstring.

Changes

  • src/transformers/training_args.py — new ddp_static_graph: bool | None = field(default=None, …), mirroring the ddp_broadcast_buffers pattern. Docstring entry explains the supported-use-cases surface and caveats (DDP-only; incompatible with re-entrant activation checkpointing; requires stable graph).
  • src/transformers/trainer.py (_build_accelerator_args) — one new conditional following the existing if self.args.ddp_* is not None: pattern for bucket_cap_mb and broadcast_buffers. When the flag is None (default), the kwarg is never added to ddp_kwargs, so DistributedDataParallelKwargs' own default (False) applies. Strictly additive; no existing behavior changes (both changes are sketched below).
  • tests/trainer/test_trainer.py — new TrainerDDPKwargsTest class with three tests:
    • ddp_static_graph=True → handler has static_graph=True (positive).
    • ddp_static_graph=False → handler has static_graph=False (positive).
    • ddp_static_graph=None (default) → handler preserves Accelerate's default False (regression guard — the conditional in _build_accelerator_args must NOT leak the kwarg when unset).

All three pass locally.
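Roughly, the new field and the forwarding conditional described above look like the following (a sketch; exact placement and surrounding code follow the existing ddp_* patterns in training_args.py and trainer.py):

```python
# src/transformers/training_args.py (sketch, mirroring ddp_broadcast_buffers)
ddp_static_graph: bool | None = field(
    default=None,
    metadata={
        "help": (
            "When using distributed training, the value of the flag `static_graph` passed to "
            "`DistributedDataParallel`."
        )
    },
)

# src/transformers/trainer.py, inside _build_accelerator_args (sketch)
if self.args.ddp_static_graph is not None:
    ddp_kwargs["static_graph"] = self.args.ddp_static_graph
```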

Interaction caveats (documented in help text)

  • DDP-only. The field has no effect under FSDP or DeepSpeed (the conditional only fires on the DDP path, matching the other ddp_* fields).
  • Gradient checkpointing. Using static_graph=True alongside re-entrant activation checkpointing (use_reentrant=True) is unsafe per PyTorch; the docstring warns about this. Non-reentrant checkpointing (use_reentrant=False) is fine (see the sketch after this list).
  • Requires stable graph. Modules with data-dependent control flow that changes which parameters are touched per iteration are incompatible with static_graph=True by PyTorch's own contract.
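A minimal sketch of the safe combination from the gradient-checkpointing caveat above (gradient_checkpointing_kwargs is the existing TrainingArguments field; whether the graph is actually static across iterations remains the user's responsibility):

```python
from transformers import TrainingArguments

# Non-reentrant activation checkpointing is compatible with a static DDP graph;
# re-entrant checkpointing (use_reentrant=True) is not.
args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    ddp_static_graph=True,
)
```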


Who can review?

@SunMarc — Trainer / Accelerate integration.

Expose PyTorch DDP's `static_graph` flag as a new `ddp_static_graph: Optional[bool]`
field on `TrainingArguments`, forwarded through `Trainer._build_accelerator_args`
into Accelerate's `DistributedDataParallelKwargs` (which already supports it).

Completes the set of DDP flags partially exposed today
(`ddp_find_unused_parameters`, `ddp_bucket_cap_mb`, `ddp_broadcast_buffers`).

Defaults to `None`; when unset, the kwarg is never added to `ddp_kwargs`, so
Accelerate's own default (`False`) applies — strictly additive, no existing
behavior changes.

See issue huggingface#45518 for full motivation: users with partially frozen models
(e.g. the frozen-LLM head-tuning pattern) today either hit `Expected to have
finished reduction in the prior iteration...` or must pay per-iteration
`find_unused_parameters=True` traversal cost. `static_graph=True` is the
performance-optimal third option that Accelerate/PyTorch already support
but that Trainer couldn't expose.

Tests: positive (True, False) plus regression guard (None must not leak the
kwarg).

Fixes huggingface#45518
Member

@SunMarc SunMarc left a comment


Thanks! Just a nit

Comment thread src/transformers/training_args.py Outdated
Comment on lines +640 to +649
ddp_static_graph (`bool`, *optional*):
When using distributed training, the value of the flag `static_graph` passed to
`DistributedDataParallel`. When set to `True`, DDP assumes the set of used/unused parameters and the
autograd graph topology are stable across iterations; this can resolve `Expected to have finished
reduction in the prior iteration...` errors caused by trainable parameters that don't contribute to the
loss on every step, and is generally cheaper than `ddp_find_unused_parameters=True` (one iteration-1
cost vs. a per-iteration autograd-graph traversal). Has no effect under FSDP or DeepSpeed; incompatible
with re-entrant activation checkpointing (`use_reentrant=True`). See the PyTorch
[`DistributedDataParallel` docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)
for the full list of supported use cases.
Member


Keep it short like the other args; interested users will check the DDP docs for that.

Contributor Author

@KeitaW KeitaW Apr 21, 2026


Done — trimmed to two lines matching the ddp_broadcast_buffers / ddp_bucket_cap_mb style. Thanks!

Comment thread src/transformers/training_args.py Outdated
metadata={
"help": (
"When using distributed training, the value of the flag `static_graph` passed to "
"`DistributedDataParallel`. Has no effect under FSDP or DeepSpeed."
Member


Suggested change
"`DistributedDataParallel`. Has no effect under FSDP or DeepSpeed."
"`DistributedDataParallel`. "

Contributor Author


Applied — dropped the FSDP/DeepSpeed sentence from the help string. Thanks!

@SunMarc SunMarc enabled auto-merge April 21, 2026 12:22
@SunMarc SunMarc added this pull request to the merge queue Apr 21, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Merged via the queue into huggingface:main with commit 26bb358 Apr 21, 2026
28 checks passed

Development

Successfully merging this pull request may close these issues.

  • Expose static_graph DDP flag via TrainingArguments
  • DDP grads not synced when static_graph=True

3 participants