
Conversation

@tomjen12

Motivation

When running distributed data parallel (DDP) training with torch.compile enabled, training crashes during the AOTAutograd graph partitioning phase (min_cut_rematerialization_partition).

The error indicates that a tensor view operation in the backward graph is invalid relative to the graph partition boundary:

torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
AssertionError: Node view_21 was invalid, but is output

Technical Details

The _flash_attn_backward op mutates its gradient arguments (dq, dk, dv) in-place. These side effects were previously not registered on the op, so AOTAutograd misread the aliasing relationships (e.g., view ops on the gradients) and flagged the corresponding nodes as invalid.

I updated the @torch_compile_guard decorator to explicitly declare mutates_args=["dq", "dk", "dv"], which allows correct schema inference and graph partitioning:

@torch_compile_guard(mutates_args=["dq", "dk", "dv"], gen_fake=_flash_attn_backward_fake)
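
For context, this follows the same pattern as PyTorch's public torch.library.custom_op API (PyTorch 2.4+), where declaring the mutated arguments lets the compiler infer a schema that marks them mutable (e.g., Tensor(a!) dq) instead of assuming the op is functional. Below is a minimal sketch with made-up names; demo::accumulate_grad merely stands in for the flash-attention backward op:

    import torch
    from torch import Tensor

    # Declaring mutates_args tells torch.compile / AOTAutograd that "dq" is
    # written in-place, so views of it are tracked correctly when the joint
    # graph is partitioned.
    @torch.library.custom_op("demo::accumulate_grad", mutates_args=["dq"])
    def accumulate_grad(grad_out: Tensor, dq: Tensor) -> None:
        dq.add_(grad_out)  # in-place write, visible to the caller

    @accumulate_grad.register_fake
    def _(grad_out: Tensor, dq: Tensor) -> None:
        # Fake impl for tracing: shape/dtype only, no real data. Analogous
        # to the gen_fake=_flash_attn_backward_fake hook used in this PR.
        return None

    def demo(x: Tensor) -> Tensor:
        buf = torch.zeros_like(x)
        torch.ops.demo.accumulate_grad(x, buf)
        return buf.view(-1)  # a view of the mutated tensor now partitions cleanly

    print(torch.compile(demo)(torch.ones(4)))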

Test Plan

Run a training loop on a GPT-3 6.7B model using DDP and torch.compile(model).
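
For reference, a minimal stand-in for that loop (the model, sizes, and optimizer here are illustrative placeholders rather than the actual GPT-3 6.7B training script); launch with torchrun --nproc_per_node=<N>:

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def main() -> None:
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)

        # Tiny stand-in for the GPT-3 6.7B model used in the real test plan.
        model = torch.nn.Sequential(
            torch.nn.Linear(1024, 4096),
            torch.nn.GELU(),
            torch.nn.Linear(4096, 1024),
        ).cuda()
        model = DDP(model, device_ids=[rank])
        compiled = torch.compile(model)
        opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

        for step in range(3):
            x = torch.randn(8, 1024, device="cuda")
            loss = compiled(x).pow(2).mean()
            loss.backward()  # crash site before the fix (first backward pass)
            opt.step()
            opt.zero_grad()

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()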

Test Result

Before Fix: The script crashes immediately during the first backward pass with AssertionError: Node view_21 was invalid.

After Fix: The graph compiles successfully and the training loop proceeds without errors. Validated on the GPT-3 6.7B model.

smci355-ccs-aus-m01-17[0:00/0][2025-12-22 20:37:48,037] [train.py:129] train_step: 0, loss=0.00033572062966413796, iter_dt=146.44924759864807, fps_gpu=0.013656608229774632, fps_tot=0.10925286583819706
smci355-ccs-aus-m01-17[0:00/0][2025-12-22 20:37:51,171] [train.py:129] train_step: 1, loss=25.630146026611328, iter_dt=0.5292143821716309, fps_gpu=3.7791867858787236, fps_tot=30.23349428702979

tomjen12 requested review from a team and ZhangLirong-amd on December 23, 2025 05:10
Copilot AI review requested due to automatic review settings December 29, 2025 02:10
Copilot AI left a comment

Pull request overview

This PR fixes a crash in distributed training (DDP) with torch.compile enabled by adding mutation annotations to the _flash_attn_backward function. The issue occurred during AOTAutograd graph partitioning when the compiler couldn't properly track that certain gradient tensors were being mutated in-place.

  • Adds mutates_args=["dq", "dk", "dv"] to the @torch_compile_guard decorator for _flash_attn_backward
  • Enables correct schema inference and graph partitioning for AOTAutograd
  • Resolves the "Node view_21 was invalid, but is output" assertion error
