[Trainer] Add ddp_static_graph option#45519
Merged
SunMarc merged 2 commits intohuggingface:mainfrom Apr 21, 2026
Merged
Conversation
Expose PyTorch DDP's `static_graph` flag as a new `ddp_static_graph: Optional[bool]` field on `TrainingArguments`, forwarded through `Trainer._build_accelerator_args` into Accelerate's `DistributedDataParallelKwargs` (which already supports it). Completes the set of DDP flags partially exposed today (`ddp_find_unused_parameters`, `ddp_bucket_cap_mb`, `ddp_broadcast_buffers`). Defaults to `None`; when unset, the kwarg is never added to `ddp_kwargs`, so Accelerate's own default (`False`) applies — strictly additive, no existing behavior changes. See issue huggingface#45518 for full motivation: users with frozen trainable submodules (e.g. the LLM-frozen head-tuning pattern) today either hit `Expected to have finished reduction in the prior iteration...` or must pay per-iteration `find_unused_parameters=True` traversal cost. `static_graph=True` is the performance-optimal third option that Accelerate/PyTorch already support but that Trainer couldn't expose. Tests: positive (True, False) plus regression guard (None must not leak the kwarg). Fixes huggingface#45518
SunMarc
approved these changes
Apr 20, 2026
Comment on lines
+640
to
+649
| ddp_static_graph (`bool`, *optional*): | ||
| When using distributed training, the value of the flag `static_graph` passed to | ||
| `DistributedDataParallel`. When set to `True`, DDP assumes the set of used/unused parameters and the | ||
| autograd graph topology are stable across iterations; this can resolve `Expected to have finished | ||
| reduction in the prior iteration...` errors caused by trainable parameters that don't contribute to the | ||
| loss on every step, and is generally cheaper than `ddp_find_unused_parameters=True` (one iteration-1 | ||
| cost vs. a per-iteration autograd-graph traversal). Has no effect under FSDP or DeepSpeed; incompatible | ||
| with re-entrant activation checkpointing (`use_reentrant=True`). See the PyTorch | ||
| [`DistributedDataParallel` docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) | ||
| for the full list of supported use cases. |
Member
There was a problem hiding this comment.
keep it short like the other args, users interested will check ddp docs for that
Contributor
Author
There was a problem hiding this comment.
Done — trimmed to two lines matching the ddp_broadcast_buffers / ddp_bucket_cap_mb style. Thanks!
| metadata={ | ||
| "help": ( | ||
| "When using distributed training, the value of the flag `static_graph` passed to " | ||
| "`DistributedDataParallel`. Has no effect under FSDP or DeepSpeed." |
Member
There was a problem hiding this comment.
Suggested change
| "`DistributedDataParallel`. Has no effect under FSDP or DeepSpeed." | |
| "`DistributedDataParallel`. " |
Contributor
Author
There was a problem hiding this comment.
Applied — dropped the FSDP/DeepSpeed sentence from the help string. Thanks!
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Exposes PyTorch DDP's
static_graphflag via a newddp_static_graph: Optional[bool]field onTrainingArguments, forwarded throughTrainer._build_accelerator_argsinto Accelerate'sDistributedDataParallelKwargs(which already supports it; only the Transformers-side plumbing was missing).This completes the set of DDP flags already partially exposed on
TrainingArguments(ddp_find_unused_parameters,ddp_bucket_cap_mb,ddp_broadcast_buffers). Today a user can configure nearly everything about DDP exceptstatic_graph, and today's only workarounds are monkey-patchingDistributedDataParallel.__init__via asitecustomize.pyshim or subclassingTrainerto override_wrap_model— neither portable.Fixes #45518
Why
Per PyTorch's DDP docs,
static_graph=Truerelaxes several DDP reducer constraints for users who can guarantee a stable graph across iterations: "Reentrant backwards", "Activation checkpointing when model has unused parameters", "There are model parameters that are outside of forward function", and "Potentially improve performance when there are unused parameters."A common HF Trainer scenario where this matters: a model with trainable parameters that don't contribute to loss on every iteration (e.g. frozen submodules, or multi-head models where only one head is trained). Under DDP with
ddp_find_unused_parameters=False(the Trainer default), such a model fails at iter 1 with:The common workaround is
ddp_find_unused_parameters=True, but that forces DDP to traverse the autograd graph on every iteration to find unused params — a measurable per-step cost.static_graph=Trueis the performance-optimal alternative (PyTorch's own note: "potentially improve performance when there are unused parameters"): DDP records the participating-parameter set on iter 1 and assumes it's stable thereafter.The earlier blocker that once made
static_graph=Trueunsafe with HF models (ModelOutputsubclasses not registered as pytree nodes) was fixed in #25358 (closed #25357, merged 2023-08) and is still live insrc/transformers/utils/generic.py— where the repo itself documentsstatic_graph=Truesafety withModelOutputin a docstring.Changes
src/transformers/training_args.py— newddp_static_graph: bool | None = field(default=None, …), mirroring theddp_broadcast_bufferspattern. Docstring entry explains the supported-use-cases surface and caveats (DDP-only; incompatible with re-entrant activation checkpointing; requires stable graph).src/transformers/trainer.py(_build_accelerator_args) — one new conditional following the existingif self.args.ddp_* is not None:pattern forbucket_cap_mbandbroadcast_buffers. When the flag isNone(default), the kwarg is never added toddp_kwargs, soDistributedDataParallelKwargs' own default (False) applies. Strictly additive; no existing behavior changes.tests/trainer/test_trainer.py— newTrainerDDPKwargsTestclass with three tests:ddp_static_graph=True→ handler hasstatic_graph=True(positive).ddp_static_graph=False→ handler hasstatic_graph=False(positive).ddp_static_graph=None(default) → handler preserves Accelerate's defaultFalse(regression guard — the conditional in_build_accelerator_argsmust NOT leak the kwarg when unset).All three pass locally.
Interaction caveats (documented in help text)
ddp_*fields).static_graph=Truealongside re-entrant activation checkpointing (use_reentrant=True) is unsafe per PyTorch; the docstring warns. Non-reentrant checkpointing (use_reentrant=False) is fine.static_graph=Trueby PyTorch's own contract.Before submitting
Who can review?
@SunMarc — Trainer / Accelerate integration.