
[GraphTrainer] Tag flex_attn via a graph pass in place of fx.annotation for better robustness #2924

Merged
SherlockNoMad merged 4 commits into main from
graph_trainer/annotate_flex_for_regional_inductor
Apr 10, 2026

Conversation

Contributor

@SherlockNoMad SherlockNoMad commented Apr 9, 2026

Summary

Moves flex attention annotation from a pre-tracing function/context-manager
(annotate_flex_for_regional_inductor / annotate_flex_attention_for_regional_inductor
in common_utils.py) to a post-tracing graph pass
(annotate_flex_attention_for_regional_inductor_pass in passes.py).

Why: The previous approach annotated Python-level functions before tracing,
which required a context manager to temporarily patch and restore
FlexAttention._compiled_flex_attn and _compiled_create_block_mask.
A graph pass is simpler — it directly tags the relevant FX nodes after tracing,
with no monkey-patching or cleanup needed.

What the pass does: Annotates three sets of nodes with
compile_with_inductor (including inductor_configs from FlexAttention)
so that regional_inductor correctly scoops and compiles flex attention regions:

  1. The HOP nodes (flex_attention / flex_attention_backward)
  2. The get_attr nodes referencing score_mod / mask_mod submodules
  3. All nodes inside those submodule graphs
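
The three tagging steps above can be sketched as a small pass. This is an illustrative sketch only: `Node` below is a lightweight stand-in for `torch.fx.Node`, and the submodule name `sdpa_score0`, the `meta["custom"]` layout, and the pass signature are assumptions for illustration, not the real torchtitan code.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str       # "call_function", "get_attr", ...
    target: str   # e.g. "flex_attention" or a submodule name
    meta: dict = field(default_factory=dict)

FLEX_HOPS = {"flex_attention", "flex_attention_backward"}

def annotate_flex_attention_for_regional_inductor_pass(nodes, submodules, inductor_configs):
    """Tag (1) flex-attention HOP nodes, (2) get_attr nodes referencing
    score_mod/mask_mod submodules, and (3) every node inside those
    submodule graphs, so regional_inductor can scoop the regions."""
    tag = {"compile_with_inductor": "flex_attention",
           "inductor_configs": inductor_configs}
    tagged_submods = set()
    for n in nodes:
        if n.op == "call_function" and n.target in FLEX_HOPS:
            n.meta["custom"] = tag                 # (1) the HOP node itself
        elif n.op == "get_attr" and n.target in submodules:
            n.meta["custom"] = tag                 # (2) score_mod/mask_mod refs
            tagged_submods.add(n.target)
    for name in tagged_submods:
        for sub_n in submodules[name]:             # (3) nodes inside submodules
            sub_n.meta["custom"] = tag
    return nodes
```

Because the pass runs on the traced graph, unrelated nodes are left untouched, which is the over-annotation the pre-tracing approach could not avoid.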

Changes:

  • Add annotate_flex_attention_for_regional_inductor_pass graph pass in passes.py
  • Remove annotate_flex_for_regional_inductor() and its context manager from common_utils.py
  • Remove pre-tracing annotation calls from llama3/parallelize.py and deepseek_v3/parallelize.py
  • Wire up the pass in graph_utils.py (applied as a joint pass before regional_inductor)
  • Update tests to use the graph pass instead of the context manager
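
The wiring change above amounts to an ordering constraint: the annotation pass must run before `regional_inductor` so the tags already exist when regions are scooped. A toy sketch of that constraint, where the dict-based "graph", `annotate_flex_pass`, and `regional_inductor_stub` are illustrative stand-ins, not the real `graph_utils.py` machinery:

```python
def apply_joint_passes(graph, passes):
    """Run each joint graph pass in order; each returns the graph
    for the next pass to consume."""
    for p in passes:
        graph = p(graph)
    return graph

def annotate_flex_pass(graph):
    # Stand-in annotation pass: tags flex regions for a later pass.
    graph.setdefault("tags", []).append("compile_with_inductor")
    return graph

def regional_inductor_stub(graph):
    # Stand-in for regional_inductor: compiles only regions that an
    # earlier pass already tagged.
    graph["compiled_regions"] = list(graph.get("tags", []))
    return graph
```

With the passes in the reverse order, `compiled_regions` comes out empty, which is exactly the failure mode the "before regional_inductor" placement guards against.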

Test plan

  • test_passes.py — passed
  • test_precompile.py — passed
  • test_trace_module.py — 27/28 passed (1 pre-existing failure in test_peak_memory_identical_fsdp)
  • test_numerics.py — passed
  • test_bitwise_deterministic.py — passed
  • pre-commit run --all-files — passed

…l_inductor

Refactor flex attention annotations to tag `_compiled_flex_attn` and
`_compiled_create_block_mask` with `compile_with_inductor` metadata
including inductor configs, instead of annotating `FlexAttention.forward`.

This ensures bitwise-identical kernels between eager and regional_inductor
paths by propagating the same inductor configs used by
`FlexAttention._compiled_flex_attn`.

- Add `annotate_flex_for_regional_inductor()` for permanent annotations
- Update context manager to use the new function and restore originals
- Unify llama3 and deepseek_v3 parallelize to use the shared helper
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 9, 2026
)
MoE.forward = annotate_fn({"EP": "compute"})(MoE.forward)

FlexAttention.forward = annotate_fn({"compile_with_inductor": "flex_attention"})(
Contributor Author


Due to the change in #2761,

wrapping the entire FlexAttention.forward annotates more nodes than are supposed to be compiled.

Hence this fix.

{"compile_with_inductor": "flex_attention"} so the compiler can apply
regional inductor pass based on the annotation. Regional inductor is now only
supported in AOT mode.
- Flex attention annotation: Tags FlexAttention.forward and compiled flex
Contributor Author


removed.

@SherlockNoMad SherlockNoMad force-pushed the graph_trainer/annotate_flex_for_regional_inductor branch from d9baee0 to c47515f Compare April 10, 2026 06:35
@SherlockNoMad SherlockNoMad changed the title [GraphTrainer] Annotate compiled flex attention functions for regional_inductor [GraphTrainer] annotate_flex_attention_for_regional_inductor_pass Apr 10, 2026
@SherlockNoMad SherlockNoMad marked this pull request as ready for review April 10, 2026 06:57
@SherlockNoMad SherlockNoMad changed the title [GraphTrainer] annotate_flex_attention_for_regional_inductor_pass [GraphTrainer] Changed tagging flex_attn via graph_pass in replace of fx.annotation for better robustness Apr 10, 2026
@SherlockNoMad SherlockNoMad merged commit e24e465 into main Apr 10, 2026
17 of 25 checks passed
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request Apr 13, 2026

Labels

ciflow/8gpu, CLA Signed (managed by the Meta Open Source bot)

3 participants