[GraphTrainer][AutoDev] Add integration test for fsdp_reshard_after_fwd joint pass #2911
SherlockNoMad wants to merge 2 commits into main
…wd joint pass Port from PR #2809: adds an AOT mode integration test that exercises the `--parallelism.fsdp_reshard_after_forward always` flag with FSDP+TP (dp_shard=4, tp=2) on 8 GPUs. This ensures the reshard-after-forward code path is covered by GraphTrainer CI.
    [
        "--module graph_trainer.llama3",
        "--config graph_trainer_llama3_debugmodel",
        "--compile.mode aot",
You should do this for aot_fx_trace mode; we are about to deprecate aot mode.
The current GraphTrainer pass pipeline hasn't implemented FSDP-related passes yet, but it should.
Give enabling fsdp_bucketing and resharding a shot, and let me know what doesn't work.
GraphTrainer runs its passes on the forward_loss_backward joint graph, so you will need to adjust the existing FSDP passes for that.
AutoDev: Moved the test from AOT to aot_fx_trace mode. Here are my findings on enabling FSDP passes in aot_fx_trace mode:
fsdp_reshard_after_fwd_pass does not work in aot_fx_trace mode. The existing pass (annotate_fsdp_all_gather) annotates forward all-gather nodes with CheckpointPolicy.MUST_RECOMPUTE/MUST_SAVE metadata. These annotations are consumed by AOT Autograd's min-cut partitioner to decide whether to keep all-gathered weights alive or re-materialize them in the backward. In aot_fx_trace mode, make_fx traces the entire fwd+loss+bwd as a single flat graph with no partitioner — so the annotations have no consumer and are effectively a no-op.
The fundamental difference: in AOT mode, the "reshard after forward" decision is a partitioning choice (save vs. recompute). In aot_fx_trace mode, the backward is already baked into the graph. Making reshard-after-forward work here would require a different approach — identifying forward all-gather outputs used in the backward and inserting explicit free + re-gather operations in the traced graph.
fsdp_bucketing passes (auto_bucketing / transformer_block_bucketing) should conceptually work on the joint graph. They operate on FX graphs by finding collective ops (all-gather, reduce-scatter) and reordering nodes for comm/compute overlap. They don't depend on the graph being a partitioned fwd or bwd graph. However, they are currently only wired up in the AOT compiler pass pipeline (AVAILABLE_COMPILER_PASSES), not in construct_default_graph_passes for aot_fx_trace. Wiring them in should be straightforward.
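To illustrate why a collective-reordering pass does not need a partitioned graph, here is a minimal toy sketch. It is not the auto_bucketing or transformer_block_bucketing implementation and does not use torch.fx; the tuple-based graph, `find_collectives`, and `hoist_all_gathers` are hypothetical names invented for this example, and wait_tensor nodes are omitted for brevity. It only shows that locating collectives and moving them earlier works the same way on a single flat forward+backward node list.

```python
# Toy joint graph: (name, op, inputs) tuples in topological order.
# Assumption: this is a stand-in for an FX graph, not the real data structure.
COLLECTIVE_OPS = {"all_gather", "reduce_scatter"}


def find_collectives(graph):
    """Return the names of all collective nodes, forward or backward."""
    return [name for name, op, _ in graph if op in COLLECTIVE_OPS]


def hoist_all_gathers(graph):
    """Move each all_gather to just after its last-defined input.

    Moving a node earlier (but after its inputs) keeps the order valid,
    and lets the communication overlap with the compute it now precedes.
    """
    out = list(graph)
    for node in graph:
        _, op, inputs = node
        if op != "all_gather":
            continue
        out.remove(node)
        pos = {n[0]: i for i, n in enumerate(out)}
        insert_at = max((pos[i] for i in inputs), default=-1) + 1
        out.insert(insert_at, node)
    return out


joint_graph = [
    ("x", "placeholder", []),
    ("w1_shard", "placeholder", []),
    ("w2_shard", "placeholder", []),
    ("ag1", "all_gather", ["w1_shard"]),
    ("h", "matmul", ["x", "ag1"]),
    ("ag2", "all_gather", ["w2_shard"]),
    ("y", "matmul", ["h", "ag2"]),
    ("gy", "ones_like", ["y"]),            # start of the backward portion
    ("gw2", "matmul", ["h", "gy"]),
    ("rs2", "reduce_scatter", ["gw2"]),
]

collectives = find_collectives(joint_graph)
hoisted_names = [name for name, _, _ in hoist_all_gathers(joint_graph)]
```

In this toy, `ag2` ends up ahead of the `h` matmul, so the second all-gather could overlap with the first layer's compute. A real pass would bound how far collectives are hoisted, since gathering every weight up front defeats the memory savings of FSDP.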
What would need to change to properly support reshard_after_forward in aot_fx_trace:
A new pass would need to operate on the joint graph to:
- Find forward all-gather + wait_tensor chains from SimpleFSDP (same pattern as is_wait_tensor_from_fsdp)
- For each such chain, identify all backward uses of the all-gathered tensor
- Insert a new all-gather + wait_tensor pair just before the backward uses, and rewire the backward consumers to use the new all-gather output
- This effectively duplicates the forward all-gather in the backward portion of the graph, allowing the forward copy to be freed after forward completes
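The steps above can be sketched on a toy graph. This is a rough illustration under stated assumptions, not the proposed pass: it uses a plain Python `Node` dataclass instead of torch.fx, and the `is_backward` and `from_fsdp` tags are hypothetical (a real pass would have to infer which nodes belong to the backward and which all-gathers come from SimpleFSDP, for example with a check like is_wait_tensor_from_fsdp).

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str                                   # e.g. "all_gather", "wait_tensor", "matmul"
    inputs: list = field(default_factory=list)
    is_backward: bool = False                 # hypothetical tag for backward nodes
    from_fsdp: bool = False                   # hypothetical tag for FSDP all-gathers


def reshard_after_forward(graph):
    """Duplicate each forward FSDP all-gather chain before its backward uses.

    Note: rewires the inputs of backward consumer nodes in place.
    """
    by_name = {n.name: n for n in graph}
    # Step 1: forward wait_tensor nodes fed by an FSDP all-gather.
    chains = [
        n for n in graph
        if n.op == "wait_tensor" and not n.is_backward
        and by_name[n.inputs[0]].op == "all_gather"
        and by_name[n.inputs[0]].from_fsdp
    ]
    out = list(graph)
    for wait in chains:
        gather = by_name[wait.inputs[0]]
        # Step 2: backward consumers of the all-gathered tensor.
        users = [(i, n) for i, n in enumerate(out)
                 if n.is_backward and wait.name in n.inputs]
        if not users:
            continue
        # Step 3: new all-gather + wait_tensor just before the first backward use.
        new_gather = Node(gather.name + "_bwd", "all_gather",
                          list(gather.inputs), is_backward=True, from_fsdp=True)
        new_wait = Node(wait.name + "_bwd", "wait_tensor",
                        [new_gather.name], is_backward=True)
        first = users[0][0]
        out[first:first] = [new_gather, new_wait]
        for _, user in users:
            user.inputs = [new_wait.name if i == wait.name else i
                           for i in user.inputs]
    return out


graph = [
    Node("w_shard", "placeholder"),
    Node("x", "placeholder"),
    Node("ag", "all_gather", ["w_shard"], from_fsdp=True),
    Node("w", "wait_tensor", ["ag"]),
    Node("y", "matmul", ["x", "w"]),
    Node("grad_y", "ones_like", ["y"], is_backward=True),
    Node("grad_x", "matmul", ["grad_y", "w"], is_backward=True),
]
new_graph = reshard_after_forward(graph)
names = [n.name for n in new_graph]
```

After the rewrite, the forward output `w` has no backward users left, so its last use is the forward matmul and it can be freed once the forward completes; `grad_x` reads the re-gathered `w_bwd` instead.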
The test as committed will pass (the config flag is silently ignored in aot_fx_trace), which serves as a placeholder for when the proper pass is implemented.
…aot_fx_trace mode
Summary
--parallelism.fsdp_reshard_after_forward always with FSDP+TP (dp_shard=4, tp=2) on 8 GPUs
Test plan
aot_llama3_fsdp_tp_reshard_after_fwd integration test with 8 GPUs