[Fix] Add mutates_args to flash_attn_backward to fix AOTAutograd DDP … #1712
Motivation
When running distributed training (DDP) with torch.compile enabled, training crashes during AOTAutograd's graph partitioning phase (`min_cut_rematerialization_partition`).
The error indicates that a tensor view operation in the backward graph is invalid relative to the graph partition boundary.
Technical Details
The `_flash_attn_backward` op mutates the gradient tensors (`dq`, `dk`, `dv`) in-place. Previously, these side effects were not declared in the op's registration, so AOTAutograd misinterpreted the aliasing relationships (e.g., view ops on the gradients) and flagged the corresponding nodes as invalid during partitioning.
This PR updates the `@torch_compile_guard` decorator to explicitly declare `mutates_args=["dq", "dk", "dv"]`, allowing correct schema inference and graph partitioning.
```python
@torch_compile_guard(mutates_args=["dq", "dk", "dv"], gen_fake=_flash_attn_backward_fake)
```
Test Plan
Run a training loop on a GPT-3 6.7B model using DDP with torch.compile(model).
Test Result
Before Fix: The script crashes immediately during the first backward pass with `AssertionError: Node view_21 was invalid`.
After Fix: The graph compiles successfully and the training loop proceeds without errors. Validated on a GPT-3 6.7B model.
Submission Checklist