[graph_trainer] Add torch.no_grad() and graph-based SAC to traced execution #2766
tugsbayasgalan wants to merge 34 commits into gh/tugsbayasgalan/11/base from
Conversation
…al_fx_tracer
- Annotate backward FX nodes with {"remat_pass_tag": "is_backward"} during
_patch_engine_run_backward so remat_using_tags_for_fwd_loss_bwd_graph can
identify the forward/backward boundary.
- Apply remat_using_tags_for_fwd_loss_bwd_graph as a default post-trace pass.
Nodes tagged PREFER_RECOMPUTE (from selective AC) are duplicated before
backward and the forward copies are DCE'd, reducing peak memory.
- Execute traced graph under torch.no_grad() since the graph already contains
explicit backward ops. Without this, PyTorch builds a redundant autograd
graph keeping all forward intermediates alive via grad_fn references.
- Add test_llama_1b_peak_memory: verifies traced+AC peak memory is within 20%
of eager+AC on Llama 1B (BS=2, seq=2048, bf16).
[ghstack-poisoned]
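The "is_backward" annotation described above can be sketched as a tiny FX pass. Everything here (the function name, the name-based boundary heuristic, the toy graph) is a hypothetical illustration, not the PR's actual implementation:

```python
import torch
import torch.fx


def f(x):
    a = x + 1.0      # forward
    b = a * 2.0      # pretend the backward region starts at this node
    return b


gm = torch.fx.symbolic_trace(f)


# Hypothetical sketch: mark every node at and after a boundary node as
# backward, mirroring the {"remat_pass_tag": "is_backward"} metadata that
# remat_using_tags_for_fwd_loss_bwd_graph keys off.
def tag_backward_region(gm: torch.fx.GraphModule, boundary: str) -> torch.fx.GraphModule:
    in_backward = False
    for node in gm.graph.nodes:
        if node.name == boundary:
            in_backward = True
        if in_backward:
            node.meta.setdefault("custom", {})["remat_pass_tag"] = "is_backward"
    return gm


tag_backward_region(gm, "mul")
tags = [node.meta.get("custom", {}).get("remat_pass_tag") for node in gm.graph.nodes]
print(tags)  # forward nodes stay untagged; "mul" and "output" carry the tag
```

In the real tracer the boundary comes from patching the autograd engine entry point, not from node names; the name check above only stands in for that hook.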
```python
# Rematerialize activations tagged PREFER_RECOMPUTE by selective AC.
# Duplicates recomputable forward ops before the backward region and
# DCEs the original copies, reducing peak memory.
traced = remat_using_tags_for_fwd_loss_bwd_graph(traced)
```
No, users should apply this themselves if they want it.
Keep the tracing simple.
Isn't it kind of weird that the user annotates nodes with AC and then doesn't see any improvement in the end? It feels like silent incorrectness to me.
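For intuition, the rematerialization idea behind the quoted pass can be sketched as a toy FX transform. The tag names are the ones quoted in this PR, but the graph, node names, and pass body are an illustrative sketch; the real remat_using_tags_for_fwd_loss_bwd_graph is more involved:

```python
import operator

import torch
import torch.fx


def f(x):
    a = x + 1.0          # recomputable forward op
    fwd_out = a * 2.0    # forward consumer
    bwd_out = a * 3.0    # pretend this consumer lives in the backward region
    return fwd_out, bwd_out


gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.name == "add":
        node.meta["recompute"] = "PREFER_RECOMPUTE"
    if node.name == "mul_1":
        node.meta.setdefault("custom", {})["remat_pass_tag"] = "is_backward"


# Toy remat: clone each recompute-tagged node right before its first
# backward user, so the original's output no longer has to stay alive
# across the whole forward/backward gap.
def remat_tagged(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    g = gm.graph
    for node in list(g.nodes):
        if node.meta.get("recompute") != "PREFER_RECOMPUTE":
            continue
        bwd_users = [u for u in node.users
                     if u.meta.get("custom", {}).get("remat_pass_tag") == "is_backward"]
        if not bwd_users:
            continue
        with g.inserting_before(bwd_users[0]):
            clone = g.node_copy(node)
        for u in bwd_users:
            u.replace_input_with(node, clone)
    g.eliminate_dead_code()
    gm.recompile()
    return gm


gm = remat_tagged(gm)
x = torch.tensor([1.0, 2.0])
out = gm(x)
# Numerics are unchanged: (2 * (x + 1), 3 * (x + 1))
```

In this toy the original add still has a forward user, so DCE keeps it; in the real pass, forward copies whose only remaining users were backward nodes get eliminated, which is where the peak-memory saving comes from.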
```python
for node in traced.graph.nodes:
    ac_id = node.meta.get("custom", {}).get("ac_graph_id")
    if ac_id is not None:
        node.meta["ac_graph_id"] = ac_id
```
Again, remove this. It should be a post-processing pass in the user's control.
```python
@contextmanager
def _patch_checkpoint_wrapper_ac_graph_id() -> Generator[None, None, None]:
```
No, let's have a clear separation of "tracer code" and "trainer code".
AC-related optimization is clearly a user decision and should live in the trainer code.
SherlockNoMad left a comment
The only change needed is:

```python
with torch.no_grad():
    flat_outputs = self.gm(*flat_inputs)
```

All the other changes shouldn't live in tracer code.
```python
def _patched_forward(self, *args, **kwargs):
    global _ac_graph_id_counter
    _ac_graph_id_counter += 1
    with torch.fx.traceback.annotate({"ac_graph_id": _ac_graph_id_counter}):
```
@tugsbayasgalan can you check whether, if we don't do anything, the tracer already does the right thing and just recomputes the forward into the backward? Because if yes, then we don't need anything.
Here is the full summary: pytorch/pytorch#178935
…cution

Execute traced graph under torch.no_grad() since it already contains explicit backward ops (from torch.autograd.grad traced by make_fx). Without this, PyTorch builds a redundant autograd graph keeping all forward intermediates alive via grad_fn references.

Replace the monkey-patched _CachingTorchDispatchMode AC approach with clean graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass for post-hoc tagging, strip recompute tags from backward nodes, then remat_using_tags_for_fwd_loss_bwd_graph for the remat transform.

Results on Llama 1B (H100): traced graph SAC uses 18.23 GB vs eager SAC 19.60 GB (0.93x ratio), with bitwise identical losses and gradients.

ghstack-source-id: 49ff06c
Pull Request resolved: #2766
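The no_grad argument can be reproduced on a toy model. This is a sketch under the assumption that make_fx traces through torch.autograd.grad into one joint forward+backward graph; the model and step function are hypothetical, not the PR's trainer code:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


# Toy "train step": the loss and its gradient are computed in one
# function, so make_fx bakes explicit backward ops into the graph.
def train_step(w, x):
    loss = (x @ w).relu().sum()
    (grad_w,) = torch.autograd.grad(loss, w)
    return loss, grad_w


w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4)
gm = make_fx(train_step)(w, x)

# The traced graph already contains the backward, so execute it under
# no_grad; otherwise autograd records a second, redundant graph whose
# grad_fn references keep every forward intermediate alive.
with torch.no_grad():
    loss, grad_w = gm(w, x)
```

Running the module outside no_grad would still produce the right numbers; the cost is purely the extra autograd bookkeeping and the forward intermediates it pins, which is exactly the peak-memory regression the commit message describes.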
```python
from torchtitan.trainer import Trainer


def _strip_recompute_from_backward_nodes(gm: torch.fx.GraphModule) -> None:
```
Not sure if this function should stay in trainer.py
moved to passes
```python
node.meta.pop("ac_graph_id", None)


def apply_ac_remat_pass(traced: TracedResult) -> None:
```
Similarly, this shouldn't be in trainer.py. We have a dedicated passes.py in experiments/graph_trainer/, which contains the passes for the traditional JIT and AOT graphs (dynamo + AOT Autograd).
I imagine the passes will be different for the make_fx tracer, but they might share common components.
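A minimal pass in that passes.py style, following the GraphModule-in/GraphModule-out convention. The tag names are the ones quoted in this PR; the pass body and the toy graph are an illustrative sketch, not the actual torchtitan code:

```python
import torch
import torch.fx


# Sketch of the "strip recompute tags from backward nodes" step as a
# standalone GraphModule -> GraphModule pass.
def strip_recompute_from_backward_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        custom = node.meta.get("custom", {})
        if custom.get("remat_pass_tag") == "is_backward":
            # Backward nodes must never themselves be marked for recompute.
            node.meta.pop("recompute", None)
            node.meta.pop("ac_graph_id", None)
    return gm


def f(x):
    return (x + 1.0) * 2.0


gm = torch.fx.symbolic_trace(f)
# Simulate a node that was (wrongly) tagged both recomputable and backward.
for node in gm.graph.nodes:
    if node.name == "mul":
        node.meta["recompute"] = "PREFER_RECOMPUTE"
        node.meta.setdefault("custom", {})["remat_pass_tag"] = "is_backward"

gm = strip_recompute_from_backward_nodes(gm)
```

Keeping the signature uniform means such a pass can be chained after the make_fx trace alongside the dynamo/AOT passes, even if the two tracers need different pass lists.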
… traced execution"

Execute traced graph under torch.no_grad() since it already contains explicit backward ops (from torch.autograd.grad traced by make_fx). Without this, PyTorch builds a redundant autograd graph keeping all forward intermediates alive via grad_fn references.

Replace the monkey-patched _CachingTorchDispatchMode AC approach with clean graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass (which now skips backward-tagged nodes) for post-hoc tagging, then remat_using_tags_for_fwd_loss_bwd_graph for the remat transform. apply_ac_remat_pass now takes a GraphModule and returns a GraphModule, following the standard pass signature convention.

Results on Llama 1B (H100): traced graph SAC uses 18.23 GB vs eager SAC 19.60 GB (0.93x ratio), with bitwise identical losses and gradients.

[ghstack-poisoned]
… traced execution"

Execute traced graph under torch.no_grad() since it already contains explicit backward ops (from torch.autograd.grad traced by make_fx). Without this, PyTorch builds a redundant autograd graph keeping all forward intermediates alive via grad_fn references.

Adds a SAC option on aot_fx_trace in GraphTrainer.
Adds an option to dump the peak memory from the CUDA caching allocator.
Adds a test similar to test_bitwise_equivalency.py for peak memory.

Here is the run comparison between aot and aot_fx_trace:

<img width="918" height="750" alt="Screenshot 2026-04-12 at 11 39 01 PM" src="https://github.com/user-attachments/assets/059188c5-fc14-4a34-876a-966d938bef3b" />

[ghstack-poisoned]
@claude Give review.
Claude finished @tugsbayasgalan's task in 3m 42s (View job)

Review of PR #2766 — answers to your 4 questions

Q1: Is the SAC peak test minimal enough?
Good: …
Issues: …
Overall: the test is minimal enough for its purpose. The main suggestion is to unify the …

Q2: Am I following torchtitan coding format?
Mostly yes, with a few items.
Good: …
Issues: …

Q3: Does the backward phase tagging make sense?
The approach is sound and well-documented. The mechanism: …
Good: …
Concern, reviewer feedback on naming: @SherlockNoMad's feedback was: …
Your current implementation uses …
One edge case to consider: if a model has nested …

Q4: Does the SAC pass plumbing make sense?
The flow: …
Good: …
Issues: …

Summary
Top 3 actionable items: …
Stack from ghstack (oldest at bottom):
Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.
Adds a SAC option on aot_fx_trace in GraphTrainer.
Adds an option to dump the peak memory from the CUDA caching allocator.
Adds a test similar to test_bitwise_equivalency.py for peak memory.
Here is the run comparison between aot and aot_fx_trace: