
[graph_trainer] Add torch.no_grad() and graph-based SAC to traced execution#2766

Open
tugsbayasgalan wants to merge 34 commits into gh/tugsbayasgalan/11/base from gh/tugsbayasgalan/11/head

Conversation

@tugsbayasgalan
Contributor

@tugsbayasgalan tugsbayasgalan commented Mar 31, 2026

Stack from ghstack (oldest at bottom):

Execute the traced graph under torch.no_grad(), since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph that keeps all
forward intermediates alive via grad_fn references.
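The situation can be sketched with public make_fx APIs; the toy fwd_loss_bwd function below is illustrative, not the trainer's real step:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# make_fx records torch.autograd.grad as explicit graph ops, so the
# resulting GraphModule computes gradients without autograd at runtime.
def fwd_loss_bwd(x, w):
    loss = (x @ w).sum()
    (gw,) = torch.autograd.grad(loss, w)
    return loss, gw

torch.manual_seed(0)
x = torch.randn(4, 8)
w = torch.randn(8, 2, requires_grad=True)
gm = make_fx(fwd_loss_bwd)(x, w)

# Without no_grad(), replaying the graph would build a second autograd
# graph over the already-explicit backward ops, keeping every forward
# intermediate alive through grad_fn references.
with torch.no_grad():
    loss, gw = gm(x, w)
```

After the no_grad() run, the returned gradient carries no grad_fn, confirming no redundant autograd graph was built.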

Adds an SAC option for aot_fx_trace on GraphTrainer.

Adds an option to dump peak memory from the CUDA caching allocator.

Adds a peak-memory test similar to test_bitwise_equivalency.py.

Here is the run comparison between aot and aot_fx_trace:

(Screenshot: peak-memory run comparison between aot and aot_fx_trace, 2026-04-12)

…al_fx_tracer

- Annotate backward FX nodes with {"remat_pass_tag": "is_backward"} during
  _patch_engine_run_backward so remat_using_tags_for_fwd_loss_bwd_graph can
  identify the forward/backward boundary.
- Apply remat_using_tags_for_fwd_loss_bwd_graph as a default post-trace pass.
  Nodes tagged PREFER_RECOMPUTE (from selective AC) are duplicated before
  backward and the forward copies are DCE'd, reducing peak memory.
- Execute traced graph under torch.no_grad() since the graph already contains
  explicit backward ops. Without this, PyTorch builds a redundant autograd
  graph keeping all forward intermediates alive via grad_fn references.
- Add test_llama_1b_peak_memory: verifies traced+AC peak memory is within 20%
  of eager+AC on Llama 1B (BS=2, seq=2048, bf16).
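The boundary tagging in the first bullet can be sketched on a toy graph. The boundary choice and the exact node.meta placement are simplified here; the real pass tags nodes during _patch_engine_run_backward:

```python
import operator
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2

gm = torch.fx.symbolic_trace(M())

# Illustrative boundary: pretend every node from the first mul onward
# belongs to the backward region, and tag it so a later remat pass can
# split forward from backward on this marker.
in_backward = False
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.mul:
        in_backward = True
    if in_backward and node.op == "call_function":
        node.meta["remat_pass_tag"] = "is_backward"

tagged = [n.name for n in gm.graph.nodes
          if n.meta.get("remat_pass_tag") == "is_backward"]
```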

[ghstack-poisoned]
@meta-cla bot added the CLA Signed label Mar 31, 2026
tugsbayasgalan added a commit that referenced this pull request Mar 31, 2026
…al_fx_tracer

- Annotate backward FX nodes with {"remat_pass_tag": "is_backward"} during
  _patch_engine_run_backward so remat_using_tags_for_fwd_loss_bwd_graph can
  identify the forward/backward boundary.
- Apply remat_using_tags_for_fwd_loss_bwd_graph as a default post-trace pass.
  Nodes tagged PREFER_RECOMPUTE (from selective AC) are duplicated before
  backward and the forward copies are DCE'd, reducing peak memory.
- Execute traced graph under torch.no_grad() since the graph already contains
  explicit backward ops. Without this, PyTorch builds a redundant autograd
  graph keeping all forward intermediates alive via grad_fn references.
- Add test_llama_1b_peak_memory: verifies traced+AC peak memory is within 20%
  of eager+AC on Llama 1B (BS=2, seq=2048, bf16).

ghstack-source-id: 053b6db
Pull Request resolved: #2766
# Rematerialize activations tagged PREFER_RECOMPUTE by selective AC.
# Duplicates recomputable forward ops before the backward region and
# DCEs the original copies, reducing peak memory.
traced = remat_using_tags_for_fwd_loss_bwd_graph(traced)
Contributor

No, the user should apply this if they want.

Keep the tracing simple.

Contributor Author

Isn't it kind of weird that the user annotates nodes with AC and then doesn't see any improvement in the end? It feels like silent incorrectness to me.

Comment on lines +550 to +553
    for node in traced.graph.nodes:
        ac_id = node.meta.get("custom", {}).get("ac_graph_id")
        if ac_id is not None:
            node.meta["ac_graph_id"] = ac_id
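Run standalone, the hoisting in this excerpt behaves like the following sketch (the Linear module and the single tagged node are illustrative):

```python
import torch
import torch.fx

gm = torch.fx.symbolic_trace(torch.nn.Linear(2, 2))

# Simulate annotate()-style metadata nested under "custom" on the compute
# node, then hoist it to top-level meta exactly as the excerpt does.
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.meta["custom"] = {"ac_graph_id": 1}

for node in gm.graph.nodes:
    ac_id = node.meta.get("custom", {}).get("ac_graph_id")
    if ac_id is not None:
        node.meta["ac_graph_id"] = ac_id
```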
@SherlockNoMad (Contributor) Mar 31, 2026
Again, remove this. It should be a post-processing pass under the user's control.



@contextmanager
def _patch_checkpoint_wrapper_ac_graph_id() -> Generator[None, None, None]:
Contributor

No, let's have a clear separation of "tracer code" and "trainer code".

AC-related optimization is clearly a user decision and should live in the trainer code.

@SherlockNoMad (Contributor) left a comment

The only change needed is

        with torch.no_grad():
            flat_outputs = self.gm(*flat_inputs)

All other changes shouldn't live in tracer code.

    def _patched_forward(self, *args, **kwargs):
        global _ac_graph_id_counter
        _ac_graph_id_counter += 1
        with torch.fx.traceback.annotate({"ac_graph_id": _ac_graph_id_counter}):
Contributor

@tugsbayasgalan can you check: if we don't do anything, does the tracer do the right thing and just recompute the forward into the backward? Because if yes, then we don't need anything.

Contributor Author

Here is the full summary: pytorch/pytorch#178935

tugsbayasgalan added a commit that referenced this pull request Mar 31, 2026
…cution

Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Replace monkey-patched _CachingTorchDispatchMode AC approach with clean
graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass for
post-hoc tagging, strip recompute tags from backward nodes, then
remat_using_tags_for_fwd_loss_bwd_graph for the remat transform.

Results on Llama 1B (H100): traced graph SAC uses 18.23 GB vs eager SAC
19.60 GB (0.93x ratio), with bitwise identical losses and gradients.

ghstack-source-id: 49ff06c
Pull Request resolved: #2766
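The pass ordering in that commit message can be sketched with stand-in bodies; the real annotate_ac_regions / apply_sac_pass / remat pass live in graph_trainer and use richer metadata:

```python
import torch
import torch.fx

def apply_sac_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Tag compute nodes PREFER_RECOMPUTE, skipping backward-tagged ones.
    for node in gm.graph.nodes:
        if (node.op == "call_function"
                and node.meta.get("remat_pass_tag") != "is_backward"):
            node.meta["recompute"] = "PREFER_RECOMPUTE"
    return gm

def strip_recompute_from_backward_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Backward nodes must never be marked recomputable.
    for node in gm.graph.nodes:
        if node.meta.get("remat_pass_tag") == "is_backward":
            node.meta.pop("recompute", None)
    return gm

gm = torch.fx.symbolic_trace(lambda x: torch.sum(torch.relu(x)))
gm = strip_recompute_from_backward_nodes(apply_sac_pass(gm))
tagged = sum(1 for n in gm.graph.nodes
             if n.meta.get("recompute") == "PREFER_RECOMPUTE")
```

The remat transform itself (remat_using_tags_for_fwd_loss_bwd_graph) would then duplicate the tagged forward nodes before the backward region and DCE the originals.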
from torchtitan.trainer import Trainer


def _strip_recompute_from_backward_nodes(gm: torch.fx.GraphModule) -> None:
Contributor

Not sure if this function should stay in trainer.py

Contributor Author

Moved to passes.

node.meta.pop("ac_graph_id", None)


def apply_ac_remat_pass(traced: TracedResult) -> None:
Contributor

Similarly, this shouldn't be in trainer.py; we have a dedicated passes.py in experiments/graph_trainer/, which contains the passes for the traditional JIT and AOT graphs (Dynamo + AOT Autograd).

I imagine the passes will be different for the make_fx tracer, but they might share common components.

Contributor Author

Sounds good.

… traced execution"

Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Replace monkey-patched _CachingTorchDispatchMode AC approach with clean
graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass
(which now skips backward-tagged nodes) for post-hoc tagging, then
remat_using_tags_for_fwd_loss_bwd_graph for the remat transform.

apply_ac_remat_pass now takes GraphModule and returns GraphModule,
following the standard pass signature convention.

Results on Llama 1B (H100): traced graph SAC uses 18.23 GB vs eager SAC
19.60 GB (0.93x ratio), with bitwise identical losses and gradients.

[ghstack-poisoned]
@tugsbayasgalan tugsbayasgalan changed the title [graph_trainer] Add remat pass and torch.no_grad() execution to minimal_fx_tracer [graph_trainer] Add torch.no_grad() and graph-based SAC to traced execution Apr 8, 2026
… traced execution"

Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Replace monkey-patched _CachingTorchDispatchMode AC approach with clean
graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass
(which now skips backward-tagged nodes) for post-hoc tagging, then
remat_using_tags_for_fwd_loss_bwd_graph for the remat transform.

apply_ac_remat_pass now takes GraphModule and returns GraphModule,
following the standard pass signature convention.

Results on Llama 1B (H100): traced graph SAC uses 18.23 GB vs eager SAC
19.60 GB (0.93x ratio), with bitwise identical losses and gradients.

[ghstack-poisoned]
… traced execution"

Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Replace monkey-patched _CachingTorchDispatchMode AC approach with clean
graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass
(which now skips backward-tagged nodes) for post-hoc tagging, then
remat_using_tags_for_fwd_loss_bwd_graph for the remat transform.

apply_ac_remat_pass now takes GraphModule and returns GraphModule,
following the standard pass signature convention.

Results on Llama 1B (H100): traced graph SAC uses 18.23 GB vs eager SAC
19.60 GB (0.93x ratio), with bitwise identical losses and gradients.

[ghstack-poisoned]
… traced execution"

Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Replace monkey-patched _CachingTorchDispatchMode AC approach with clean
graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass
(which now skips backward-tagged nodes) for post-hoc tagging, then
remat_using_tags_for_fwd_loss_bwd_graph for the remat transform.

apply_ac_remat_pass now takes GraphModule and returns GraphModule,
following the standard pass signature convention.

Results on Llama 1B (H100): traced graph SAC uses 18.23 GB vs eager SAC
19.60 GB (0.93x ratio), with bitwise identical losses and gradients.

[ghstack-poisoned]
tugsbayasgalan added a commit that referenced this pull request Apr 12, 2026
…cution

Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Replace monkey-patched _CachingTorchDispatchMode AC approach with clean
graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass
(which now skips backward-tagged nodes) for post-hoc tagging, then
remat_using_tags_for_fwd_loss_bwd_graph for the remat transform.

apply_ac_remat_pass now takes GraphModule and returns GraphModule,
following the standard pass signature convention.

Results on Llama 1B (H100): traced graph SAC uses 18.23 GB vs eager SAC
19.60 GB (0.93x ratio), with bitwise identical losses and gradients.

ghstack-source-id: 4f5edea
Pull Request resolved: #2766
tugsbayasgalan added a commit that referenced this pull request Apr 12, 2026

ghstack-source-id: 5c30f21
Pull Request resolved: #2766
… traced execution"


Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Adds a SAC option for aot_fx_trace on GraphTrainer.

Adds an option to dump peak memory from the CUDA caching allocator.

Adds a test similar to test_bitwise_equivalency.py for peak memory.

Here is the run comparison between aot and aot_fx_trace:

<img width="918" height="750" alt="Screenshot 2026-04-12 at 11 39 01 PM" src="https://github.com/user-attachments/assets/059188c5-fc14-4a34-876a-966d938bef3b" />


[ghstack-poisoned]
@tugsbayasgalan
Contributor Author

@claude Please review.

  1. Is the SAC peak test minimal enough?
  2. Am I following the torchtitan coding format?
  3. Does the backward phase tagging make sense?
  4. Does the SAC pass plumbing make sense?

@claude

claude bot commented Apr 13, 2026

Claude finished @tugsbayasgalan's task in 3m 42s


Review of PR #2766 — Answers to your 4 questions

  • Read all changed files and gather context
  • Q1: Is the SAC peak test minimal enough?
  • Q2: Am I following torchtitan coding format?
  • Q3: Does the backward phase tagging make sense?
  • Q4: Does the SAC pass plumbing make sense?

Q1: Is the SAC peak test minimal enough?

The test (test_sac_peak_memory.py) is structured well and follows the same pattern as test_bitwise_deterministic.py. A few observations:

Good:

  • Warmup step before measurement is correct — avoids allocator and one-time tracing noise
  • Comparing both reserved and active peak memory is thorough
  • Loss and gradient bitwise checks ensure the SAC pass doesn't break correctness
  • The MAX_PEAK_MEMORY_RATIO = 1.10 tolerance is reasonable

Issues:

  1. Divergence from _build_trainer in test_bitwise_deterministic.py — The bitwise test sets trainer._fwd_bwd_step_module = None at line 91 of test_bitwise_deterministic.py, but the helper here (test_sac_peak_memory.py:49-73) does not. If GraphTrainer.forward_backward_step ever accesses _fwd_bwd_step_module, this test would break. Consider keeping both _build_trainer functions consistent, or better yet, extracting a shared helper.

  2. config mock shape differs — In test_bitwise_deterministic.py:84-89, the config mock only has compile with mode and enable_passes. In test_sac_peak_memory.py:60-68, you added passes, joint_passes, and activation_checkpoint. This is fine for now, but these two test helpers are diverging. If both tests evolve independently, the mocked config will drift from the real GraphTrainerCompileConfig.

  3. Memory units label says GiB but uses / 1e9 — At lines 105-106, you divide by 1e9 (which gives GB, not GiB). GiB would be / (1024**3). This is cosmetic since both eager and traced use the same divisor, so the ratio is unaffected, but the variable names (reserved_gib, active_gib) are technically inaccurate.

  4. No torch.use_deterministic_algorithms(True) needed for memory test — The setUp calls _set_deterministic() which enables deterministic algorithms, and tearDown disables it. This is necessary for the bitwise correctness checks, but deterministic mode can introduce overhead on some ops. Since this is a memory test, the deterministic mode could slightly affect peak memory. Not a blocker, just worth being aware of.

  5. Test only covers debugmodel without FlexAttention — The bitwise deterministic test covers multiple model variants (Llama3, DSv3, FlexAttn). The peak memory test only covers Llama3 debugmodel. This is probably fine for a first pass, but worth noting.
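On the unit point in item 3, the mismatch is just the divisor choice; a quick illustration (the 18.23 figure is the GB number quoted elsewhere in this PR, used here purely as an example):

```python
bytes_peak = 18_230_000_000  # illustrative peak in bytes

gb = bytes_peak / 1e9            # decimal gigabytes (GB)
gib = bytes_peak / (1024 ** 3)   # binary gibibytes (GiB)

# GiB is the larger unit, so the same byte count reads smaller in GiB.
assert round(gb, 2) == 18.23
assert 16.9 < gib < 17.0
```

Since eager and traced runs use the same divisor, the ratio check is unaffected either way; only the variable names are misleading.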

Overall: the test is minimal enough for its purpose. The main suggestion is to unify the _build_trainer helper with the one in test_bitwise_deterministic.py to avoid drift.


Q2: Am I following torchtitan coding format?

Mostly yes, with a few items:

Good:

  • Passes live in passes.py (addressed @yiming0416's feedback)
  • enable_graph_ac_for_mode is in common_utils.py which is the right place
  • Function signatures follow the pass convention (gm, example_inputs) -> gm
  • Error messages use ValueError for user-facing config errors

Issues:

  1. hasattr check for config field — trainer.py:115:

    if hasattr(self.config, "activation_checkpoint"):

    Per CLAUDE.md, required config fields shouldn't have None defaults, and you should validate config values explicitly. Since GraphTrainer.Config inherits from Trainer.Config, activation_checkpoint should always be present. The hasattr guard suggests uncertainty about whether the field exists. If it should always exist (it's defined in Trainer.Config), remove the hasattr and access it directly. If there's a legitimate case where it's missing (e.g., the test mocks), fix the test mocks instead.

  2. Pass application is still inline in trainer.py — Lines 131-140:

    if self.config.compile.enable_passes:
        passes = []
        if enable_graph_ac:
            passes.append(graph_ac_pass)
        passes.extend(construct_default_graph_passes(self._traced_step))

    @SherlockNoMad and @yiming0416 asked for an optimize(...) abstraction or pass manager pattern like get_compiler_passes_from_config in graph_utils.py. The AC pass is hardcoded via an if check rather than being part of the configurable pass system. This is the main outstanding reviewer feedback item.

  3. maybe_register_blockmask_pytree_node() still called unconditionally — trainer.py:120. You mentioned in the review comments this was accidentally added. It's still present and called unconditionally during tracing setup. If this is intentional (needed for FlexAttention models), it should be documented. If accidental, remove it.

  4. run_tests.py now has a hard dependency on tensorboard — The top-level import at line 11:

    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    This import runs even when --collect_peak_memory is not passed. If tensorboard is not installed, all integration tests break. This should be a lazy import inside _read_peak_memory().
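A sketch of the lazy-import pattern being suggested (json stands in for the real tensorboard EventAccumulator import, which may not be installed; read_peak_memory and the file layout are hypothetical):

```python
import json
import tempfile

def read_peak_memory(path):
    # Lazy import: the optional dependency is only pulled in when the
    # feature is actually used, so environments without it can still
    # run every other test. `json.load` stands in for the tensorboard
    # EventAccumulator import.
    from json import load
    with open(path) as f:
        return load(f)

# Merely defining read_peak_memory triggers nothing; the import
# happens on the first call.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"reserved_peak_gb": 18.23}, f)
    name = f.name

assert read_peak_memory(name)["reserved_peak_gb"] == 18.23
```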


Q3: Does the backward phase tagging make sense?

The approach is sound and well-documented. Here's the analysis:

The mechanism:

  • trainer.py:45-46: torch.fx.traceback.annotate({"phase": "backward"}) wraps the torch.autograd.grad call
  • _copy_fwd_metadata_to_bw_nodes in make_fx_tracer.py:203-231 propagates forward node's custom metadata to backward nodes sharing the same seq_nr
  • _is_backward_node in passes.py:47-60 checks node.meta.get("custom", {}).get("phase") == "backward"

Good:

  • The comment in _is_backward_node (lines 48-59) explaining why you can't use the ambient autograd_backward traceback marker is excellent. The flex_attention nested backward tracing issue is a real gotcha and this documents it well.
  • Using torch.fx.traceback.annotate is the correct mechanism — it integrates with the existing FX tracing infrastructure rather than inventing a custom annotation system.
  • The annotation flows through _copy_fwd_metadata_to_bw_nodes which copies custom metadata from forward nodes to backward nodes with matching seq_nr, so backward nodes inherit {"phase": "backward"}.

Concern — Reviewer feedback on naming:

@SherlockNoMad's feedback was:

"marking node as backward is generally useful, and doesn't only benefit 'remat_pass_tag'. So I would think of a more general annotation."

Your current implementation uses custom["phase"] == "backward" which is actually more general than the previous remat_pass_tag approach. The key name "phase" is generic and not tied to any specific pass. This seems like a reasonable response to the feedback, though it's worth confirming with the reviewer whether "phase" is general enough or if they had something else in mind (e.g., a dedicated top-level metadata key like node.meta["is_backward"] rather than nested under custom).

One edge case to consider: If a model has nested autograd.grad calls in the forward pass (e.g., for Jacobian computation), those would also get the "phase": "backward" annotation since the annotate context wraps the entire torch.autograd.grad call. This is unlikely in standard training but worth a note.
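The check described above reduces to a small predicate over node metadata; here is a self-contained sketch using a dict-based stand-in for FX nodes (the Node class is illustrative, not torch.fx.Node):

```python
class Node:
    # Stand-in for an FX node: pass-visible annotations live in a
    # dict nested under meta["custom"], as the review describes.
    def __init__(self, name, custom=None):
        self.name = name
        self.meta = {"custom": dict(custom or {})}

def is_backward_node(node):
    # Rely on the explicit {"phase": "backward"} annotation rather
    # than ambient traceback state, which mis-fires under nested
    # tracing (e.g. flex_attention's nested backward trace).
    return node.meta.get("custom", {}).get("phase") == "backward"

fwd = Node("matmul")
bwd = Node("matmul_backward", {"phase": "backward"})
assert not is_backward_node(fwd)
assert is_backward_node(bwd)
```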


Q4: Does the SAC pass plumbing make sense?

The flow:

  1. trainer.py:118-119: If AC mode is "selective", annotate_ac_regions(model) marks each transformer block with a unique ac_region_id
  2. trainer.py:133-134: graph_ac_pass is prepended to the pass list
  3. passes.py:116-120: graph_ac_pass delegates to apply_ac_on_fwd_bwd_graph
  4. passes.py:392-408: apply_ac_on_fwd_bwd_graph calls apply_sac_pass(gm) then remat_using_tags_for_fwd_loss_bwd_graph(gm)
  5. passes.py:289-389: apply_sac_pass iterates nodes, skips backward nodes via _is_backward_node, and tags forward nodes with MUST_SAVE/PREFER_RECOMPUTE

Good:

  • The SAC pass now correctly skips backward nodes at the source (passes.py:339-340) rather than tagging everything and stripping backward tags after. This addresses @SherlockNoMad's core feedback about avoiding adhoc point fixes.
  • The graph_ac_pass wrapper follows the standard pass signature (gm, example_inputs) -> gm
  • apply_ac_on_fwd_bwd_graph is a clean composition: tag → remat
  • The per-region statistics logging is helpful for debugging

Issues:

  1. apply_sac_pass is in both AVAILABLE_JOINT_PASSES and used directly by graph_ac_pass — apply_sac_pass is registered in AVAILABLE_JOINT_PASSES (line 639) for the AOT mode path, and also called directly by apply_ac_on_fwd_bwd_graph for the aot_fx_trace path. These are two different contexts (joint graph before partition vs. fwd+bwd graph after make_fx trace). The function works in both, but it's worth a docstring note that clarifies this dual usage and any assumptions that differ.

  2. graph_ac_pass ignores example_inputs — The wrapper at lines 116-120 takes example_inputs but doesn't pass it through to apply_ac_on_fwd_bwd_graph. This is fine because SAC doesn't need inputs, but it breaks the pattern if someone expects all passes to use example_inputs. A comment noting this would help.

  3. The AC pass ordering is implicit — graph_ac_pass is always first (lines 133-134 of trainer.py), before the default passes from construct_default_graph_passes. This ordering matters (SAC must happen before cudagraph), but it's not documented or enforced. If someone reorders the passes, SAC could break silently.

  4. enable_graph_ac_for_mode duplicates validation — This function exists in common_utils.py:158-167 and is also used by apply_graph_ac at line 179. The apply_graph_ac function is for the AOT path. For aot_fx_trace, you call enable_graph_ac_for_mode directly in trainer.py. Having two separate entry points for "should we do graph AC?" increases the chance of them diverging.


Summary

| Question | Verdict |
| --- | --- |
| Q1: SAC peak test minimal? | Yes, with minor cleanups (GiB label, shared _build_trainer helper) |
| Q2: Torchtitan coding format? | Mostly — main gaps: hasattr guard, tensorboard top-level import, inline pass application |
| Q3: Backward tagging? | Good design. custom["phase"] is more general than the previous remat_pass_tag. Confirm with reviewer if naming is sufficient |
| Q4: SAC pass plumbing? | Correct architecture. The skip-backward-at-source approach addresses the core reviewer feedback. Minor: dual apply_sac_pass usage should be documented, pass ordering is implicit |

Top 3 actionable items:

  1. Move the tensorboard import in run_tests.py to be lazy (inside _read_peak_memory) to avoid breaking non-memory tests
  2. Remove the hasattr(self.config, "activation_checkpoint") guard — either access the field directly or fix test mocks
  3. Consider extracting a shared _build_trainer test helper to avoid drift between test_bitwise_deterministic.py and test_sac_peak_memory.py

tugsbayasgalan added a commit that referenced this pull request Apr 13, 2026

ghstack-source-id: ec18b1d
Pull Request resolved: #2766
tugsbayasgalan added a commit that referenced this pull request Apr 13, 2026

ghstack-source-id: 7561d24
Pull Request resolved: #2766

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.
