[GraphTrainer] Tag flex_attn via graph_pass in place of fx.annotation for better robustness #2924
Merged
SherlockNoMad merged 4 commits into main on Apr 10, 2026
Conversation
…l_inductor

Refactor flex attention annotations to tag `_compiled_flex_attn` and `_compiled_create_block_mask` with `compile_with_inductor` metadata (including inductor configs), instead of annotating `FlexAttention.forward`. This ensures bitwise-identical kernels between the eager and regional_inductor paths by propagating the same inductor configs used by `FlexAttention._compiled_flex_attn`.

- Add `annotate_flex_for_regional_inductor()` for permanent annotations
- Update the context manager to use the new function and restore the originals
- Unify llama3 and deepseek_v3 parallelize to use the shared helper
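The pre-tracing annotation style described above can be sketched in plain Python. Note that `annotate_fn` and the `_annotation` attribute here are illustrative stand-ins, not the actual torchtitan helper or tracer contract; the point is that metadata is stamped on the narrow compiled entry point rather than on all of `forward`:

```python
# Schematic sketch (hypothetical names): a decorator stamps metadata on a
# function object so that a later tracing step can read it and tag the
# resulting graph nodes.
def annotate_fn(meta):
    def decorate(fn):
        fn._annotation = dict(meta)  # hypothetical attribute a tracer would read
        return fn
    return decorate

class FlexAttention:
    def _compiled_flex_attn(self, q, k, v):
        return q  # placeholder body for illustration

# Tag only the narrow compiled entry point, not the whole forward:
FlexAttention._compiled_flex_attn = annotate_fn(
    {"compile_with_inductor": "flex_attention"}
)(FlexAttention._compiled_flex_attn)
```

The same call shape as the `MoE.forward = annotate_fn(...)(MoE.forward)` pattern visible in the diff, just applied to a narrower target.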
SherlockNoMad
commented
Apr 10, 2026
```python
)
MoE.forward = annotate_fn({"EP": "compute"})(MoE.forward)


FlexAttention.forward = annotate_fn({"compile_with_inductor": "flex_attention"})(
```
Contributor
Author
Due to the change in #2761, wrapping the entire `FlexAttention.forward` annotates more than what is supposed to be compiled. Hence this fix.
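The over-annotation concern can be illustrated with a toy, torch-free sketch (all names hypothetical): an annotation active over all of `forward` tags every operation executed inside it, while annotating only the compiled flex attention call tags just that one region.

```python
import contextlib

_active = []  # stack of currently active annotations
trace = []    # (op_name, annotation) pairs recorded by the toy tracer

@contextlib.contextmanager
def annotation(meta):
    _active.append(meta)
    try:
        yield
    finally:
        _active.pop()

def op(name):
    # record the op together with whatever annotation is active
    trace.append((name, dict(_active[-1]) if _active else None))

def forward_whole():
    # wrapping all of forward: every op inherits the tag
    with annotation({"compile_with_inductor": "flex_attention"}):
        op("qkv_proj")
        op("flex_attention")
        op("out_proj")

def forward_narrow():
    # annotating only the compiled flex attention call
    op("qkv_proj")
    with annotation({"compile_with_inductor": "flex_attention"}):
        op("flex_attention")
    op("out_proj")

trace.clear(); forward_whole()
whole_tagged = sum(1 for _, m in trace if m)    # all 3 ops tagged
trace.clear(); forward_narrow()
narrow_tagged = sum(1 for _, m in trace if m)   # only flex_attention tagged
```

With the whole-forward wrapper, unrelated ops such as the projections end up marked for regional inductor compilation; the narrow annotation covers exactly the intended region.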
SherlockNoMad
commented
Apr 10, 2026
```text
{"compile_with_inductor": "flex_attention"} so the compiler can apply
regional inductor pass based on the annotation. Regional inductor is now only
supported in AOT mode.
- Flex attention annotation: Tags FlexAttention.forward and compiled flex
```
…flex_attention_for_regional_inductor_pass: d9baee0 to c47515f (compare)
aditvenk
approved these changes
Apr 10, 2026
yiming0416
approved these changes
Apr 10, 2026
This was referenced Apr 10, 2026
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request on Apr 13, 2026
… fx.annotation for better robustness (pytorch#2924)

## Summary

Moves flex attention annotation from a pre-tracing function/context-manager (`annotate_flex_for_regional_inductor` / `annotate_flex_attention_for_regional_inductor` in `common_utils.py`) to a post-tracing graph pass (`annotate_flex_attention_for_regional_inductor_pass` in `passes.py`).

**Why:** The previous approach annotated Python-level functions before tracing, which required a context manager to temporarily patch and restore `FlexAttention._compiled_flex_attn` and `_compiled_create_block_mask`. A graph pass is simpler: it directly tags the relevant FX nodes after tracing, with no monkey-patching or cleanup needed.

**What the pass does:** Annotates three sets of nodes with `compile_with_inductor` (including `inductor_configs` from `FlexAttention`) so that `regional_inductor` correctly scoops and compiles flex attention regions:

1. The HOP nodes (`flex_attention` / `flex_attention_backward`)
2. The `get_attr` nodes referencing score_mod / mask_mod submodules
3. All nodes inside those submodule graphs

**Changes:**

- Add `annotate_flex_attention_for_regional_inductor_pass` graph pass in `passes.py`
- Remove `annotate_flex_for_regional_inductor()` and its context manager from `common_utils.py`
- Remove pre-tracing annotation calls from `llama3/parallelize.py` and `deepseek_v3/parallelize.py`
- Wire up the pass in `graph_utils.py` (applied as a joint pass before regional_inductor)
- Update tests to use the graph pass instead of the context manager

## Test plan

- [x] `test_passes.py` — passed
- [x] `test_precompile.py` — passed
- [x] `test_trace_module.py` — 27/28 passed (1 pre-existing failure in `test_peak_memory_identical_fsdp`)
- [x] `test_numerics.py` — passed
- [x] `test_bitwise_deterministic.py` — passed
- [x] `pre-commit run --all-files` — passed
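The core idea of such a post-tracing pass can be sketched with `torch.fx`. This is a minimal illustration, not the torchtitan pass: `fake_flex_attn` is a hypothetical stand-in for the flex attention HOP nodes the real pass matches, and the real pass additionally tags the `get_attr` nodes for score_mod/mask_mod submodules and the nodes inside those subgraphs:

```python
import torch
import torch.fx

# Hypothetical stand-in for the flex attention HOP; torch.fx.wrap keeps it as
# a single call_function node instead of tracing through its body.
@torch.fx.wrap
def fake_flex_attn(q, k, v):
    return q + k + v

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.relu(fake_flex_attn(q, k, v))

def annotate_flex_nodes(gm: torch.fx.GraphModule, inductor_configs: dict) -> int:
    """Tag matching FX nodes with compile_with_inductor metadata, post-tracing.
    No monkey-patching: the graph is inspected after tracing is done."""
    tagged = 0
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is fake_flex_attn:
            node.meta["custom"] = {
                "compile_with_inductor": "flex_attention",
                "inductor_configs": dict(inductor_configs),
            }
            tagged += 1
    return tagged

gm = torch.fx.symbolic_trace(Attn())
num_tagged = annotate_flex_nodes(gm, {"max_autotune": True})
```

Because the tagging happens on the traced graph, nothing on `FlexAttention` needs to be patched and restored, which is the robustness win the PR title refers to.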