
[GraphTrainer][AutoDev] Add integration test for fsdp_reshard_after_fwd joint pass#2911

Draft
SherlockNoMad wants to merge 2 commits intomainfrom
graph_trainer/test-fsdp-reshard-after-fwd

Conversation

@SherlockNoMad
Contributor

Summary

Test plan

  • CI passes on the new aot_llama3_fsdp_tp_reshard_after_fwd integration test with 8 GPUs

…wd joint pass

Port from PR #2809: adds an AOT mode integration test that exercises the
`--parallelism.fsdp_reshard_after_forward always` flag with FSDP+TP
(dp_shard=4, tp=2) on 8 GPUs. This ensures the reshard-after-forward
code path is covered by GraphTrainer CI.
@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Apr 9, 2026
```
[
    "--module graph_trainer.llama3",
    "--config graph_trainer_llama3_debugmodel",
    "--compile.mode aot",
```
Contributor Author


You should do this for aot_fx_trace mode; we are about to deprecate aot mode.

The current GraphTrainer pass pipeline hasn't implemented the FSDP-related passes yet, but it should.

Give enabling fsdp_bucketing and resharding a shot, and let me know what doesn't work.

GraphTrainer runs its passes on the forward_loss_backward joint graph, so you will need to adjust the existing FSDP passes accordingly.

Contributor Author


AutoDev: Moved the test from AOT to aot_fx_trace mode. Here are my findings on enabling FSDP passes in aot_fx_trace mode:

fsdp_reshard_after_fwd_pass does not work in aot_fx_trace mode. The existing pass (annotate_fsdp_all_gather) annotates forward all-gather nodes with CheckpointPolicy.MUST_RECOMPUTE/MUST_SAVE metadata. These annotations are consumed by AOT Autograd's min-cut partitioner to decide whether to keep all-gathered weights alive or re-materialize them in the backward. In aot_fx_trace mode, make_fx traces the entire fwd+loss+bwd as a single flat graph with no partitioner — so the annotations have no consumer and are effectively a no-op.
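The annotation-then-partition flow above can be illustrated with a toy model. This is a pure-Python sketch, not the real torchtitan or AOT Autograd API: `Node`, `annotate_fsdp_all_gather`, and `partition_decision` are simplified stand-ins, and the point is only that the tag is inert unless a partitioner reads it.

```python
from dataclasses import dataclass, field
from enum import Enum

class CheckpointPolicy(Enum):
    MUST_SAVE = "must_save"
    MUST_RECOMPUTE = "must_recompute"

@dataclass
class Node:
    op: str
    meta: dict = field(default_factory=dict)

def annotate_fsdp_all_gather(graph, reshard_after_forward):
    # Tag forward all-gather nodes with a recompute policy. The tag itself
    # does nothing; it only takes effect if a partitioner consumes it.
    for node in graph:
        if node.op == "all_gather":
            node.meta["recompute"] = (
                CheckpointPolicy.MUST_RECOMPUTE if reshard_after_forward
                else CheckpointPolicy.MUST_SAVE
            )

def partition_decision(node):
    # Stand-in for the min-cut partitioner reading the annotation.
    # In aot_fx_trace's flat joint graph this consumer does not exist,
    # so the annotation is a no-op.
    return node.meta.get("recompute", CheckpointPolicy.MUST_SAVE)

graph = [Node("all_gather"), Node("matmul")]
annotate_fsdp_all_gather(graph, reshard_after_forward=True)
print(partition_decision(graph[0]).value)  # → must_recompute
```

In AOT mode the last step is performed by the real partitioner; in aot_fx_trace mode nothing ever calls it, which is why the flag is silently ignored.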

The fundamental difference: in AOT mode, the "reshard after forward" decision is a partitioning choice (save vs. recompute). In aot_fx_trace mode, the backward is already baked into the graph. Making reshard-after-forward work here would require a different approach — identifying forward all-gather outputs used in the backward and inserting explicit free + re-gather operations in the traced graph.

fsdp_bucketing passes (auto_bucketing / transformer_block_bucketing) should conceptually work on the joint graph. They operate on FX graphs by finding collective ops (all-gather, reduce-scatter) and reordering nodes for comm/compute overlap. They don't depend on the graph being a partitioned fwd or bwd graph. However, they are currently only wired up in the AOT compiler pass pipeline (AVAILABLE_COMPILER_PASSES), not in construct_default_graph_passes for aot_fx_trace. Wiring them in should be straightforward.
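To see why bucketing is position-agnostic, here is a greedy toy version that groups consecutive all-gathers into buckets over a flat op list. It is a hypothetical simplification of the real fsdp_bucketing passes (which operate on FX nodes and fuse actual collectives); the point is only that it never asks whether an op sits in the forward or backward region, so it applies to the joint graph unchanged.

```python
def bucket_collectives(ops, bucket_cap=2):
    """Greedily group runs of consecutive all_gather ops into buckets of at
    most bucket_cap, so one fused collective could replace several small
    ones. Works on any flat op list: fwd, bwd, or the joint graph."""
    buckets, current = [], []
    for op in ops:
        if op == "all_gather":
            current.append(op)
            if len(current) == bucket_cap:
                buckets.append(current)
                current = []
        elif current:
            # A non-collective op breaks the run; flush the partial bucket.
            buckets.append(current)
            current = []
    if current:
        buckets.append(current)
    return buckets

joint = ["all_gather", "all_gather", "matmul", "all_gather", "reduce_scatter"]
print(bucket_collectives(joint))
# → [['all_gather', 'all_gather'], ['all_gather']]
```

Wiring the real passes into construct_default_graph_passes would amount to registering them in that pipeline the same way AVAILABLE_COMPILER_PASSES does for AOT mode.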

What would need to change to properly support reshard_after_forward in aot_fx_trace:
A new pass would need to operate on the joint graph to:

  1. Find forward all-gather + wait_tensor chains from SimpleFSDP (same pattern as is_wait_tensor_from_fsdp)
  2. For each such chain, identify all backward uses of the all-gathered tensor
  3. Insert a new all-gather + wait_tensor pair just before the backward uses, and rewire the backward consumers to use the new all-gather output
  4. This effectively duplicates the forward all-gather in the backward portion of the graph, allowing the forward copy to be freed after forward completes

The test as committed will pass (the config flag is silently ignored in aot_fx_trace mode) and serves as a placeholder for when the proper pass is implemented.


Labels

ciflow/8gpu, CLA Signed
