
[GraphTrainer] Support transformer_block_bucketing in aot_fx_trace mode #2934

Draft

yiming0416 wants to merge 1 commit into main from graph_trainer/transformer-block-bucketing-aot-fx-trace

Conversation

Contributor

@yiming0416 yiming0416 commented Apr 10, 2026

Summary

Enable the transformer_block_bucketing pass in aot_fx_trace compile mode. Previously this pass only worked in aot and jit modes: aot_fx_trace traces a joint fwd+bwd graph in which nn_module_stack metadata alone is not sufficient to correctly identify and separate forward and backward nodes for bucketing.

Root cause: In the joint graph produced by aot_fx_trace, forward and backward collectives coexist. The existing manual_overlap_bucketing pass needs to bucket each direction independently — forward collectives in forward execution order, backward collectives in reverse order — but had no way to distinguish them.

Solution: Leverage the existing annotate_fn mechanism (already used for AC region tagging) to:

  1. Tag each bucket-level submodule's forward with its FQN (annotate_module_fqns)
  2. Mark backward nodes with is_bwd=True during metadata propagation
  3. Provide a direction-aware module_stack_fn to manual_overlap_bucketing that filters by fwd/bwd

A new joint_transformer_block_bucketing_reordering_pass applies bucketing in two passes — once for forward nodes, once for backward — so each direction's collectives are bucketed and reordered independently.
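The two-pass structure can be sketched as follows. This is a minimal, framework-free sketch: the `Node` class, the `is_bwd` key, and `bucket_by_direction` are simplified stand-ins for torch.fx nodes and the manual_overlap_bucketing machinery, not the actual GraphTrainer API.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Simplified stand-in for a torch.fx.Node carrying 'custom' metadata."""
    name: str
    meta: dict = field(default_factory=dict)

    @property
    def is_bwd(self) -> bool:
        return bool(self.meta.get("custom", {}).get("is_bwd", False))


def bucket_by_direction(nodes: list[Node]) -> tuple[list[Node], list[Node]]:
    """Split a joint fwd+bwd node list for two independent bucketing passes:
    forward collectives keep forward execution order, while backward
    collectives are visited in reverse, mirroring backward execution."""
    fwd = [n for n in nodes if not n.is_bwd]
    bwd = [n for n in nodes if n.is_bwd]
    return fwd, list(reversed(bwd))


# A joint graph interleaves forward all-gathers and backward reduce-scatters.
joint = [
    Node("ag_layer0"),
    Node("ag_layer1"),
    Node("rs_layer1", {"custom": {"is_bwd": True}}),
    Node("rs_layer0", {"custom": {"is_bwd": True}}),
]
fwd_pass, bwd_pass = bucket_by_direction(joint)
print([n.name for n in fwd_pass])  # ['ag_layer0', 'ag_layer1']
print([n.name for n in bwd_pass])  # ['rs_layer0', 'rs_layer1']
```

Each returned list can then be handed to the bucketing pass on its own, so neither direction's collectives interfere with the other's ordering.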

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 10, 2026
@yiming0416 yiming0416 force-pushed the graph_trainer/transformer-block-bucketing-aot-fx-trace branch 4 times, most recently from e7a8559 to 3d07026 Compare April 10, 2026 22:42
    custom = fwd_node.meta.get("custom")
    if custom:
        node.meta.setdefault("custom", {}).update(custom)
    node.meta.setdefault("custom", {})[_IS_BWD] = True
Contributor

We won't need to do this once @tugsbayasgalan lands the change in upstream PyTorch; the backward graph will then have a standardized metadata tag.

    return result


def annotate_module_fqns(model: nn.Module) -> None:
Contributor

cc @yushangdi, who is also adding the annotation for module path


thx! I can probably just re-use this.

@yiming0416 yiming0416 force-pushed the graph_trainer/transformer-block-bucketing-aot-fx-trace branch 2 times, most recently from 63e130f to 89b5762 Compare April 13, 2026 21:16
@yiming0416 yiming0416 changed the title [GraphTrainer][AutoDev] Support transformer_block_bucketing in aot_fx_trace mode [GraphTrainer] Support transformer_block_bucketing in aot_fx_trace mode Apr 13, 2026
@yiming0416 yiming0416 force-pushed the graph_trainer/transformer-block-bucketing-aot-fx-trace branch 2 times, most recently from 5138509 to aeadcae Compare April 13, 2026 21:38
aot_fx_trace mode traces with record_module_stack=False, so nodes lack
nn_module_stack metadata that manual_overlap_bucketing uses to match
nodes to modules. This change uses the existing annotate_fn mechanism
(already used for AC region tagging) to tag module FQNs during tracing,
then provides a custom module_stack_fn to manual_overlap_bucketing.

Changes:
- Add annotate_module_fqns() and get_module_stack_from_annotation() to
  common_utils.py
- Call annotate_module_fqns() in both llama3 and deepseek_v3 annotate
  functions
- Add module_stack_fn parameter to
  transformer_block_bucketing_reordering_pass
- Update construct_default_graph_passes to accept compile_config, model,
  and parallel_dims, and wire up the bucketing pass with reassign_to_pg
- Pass config/model/parallel_dims from GraphTrainer to
  construct_default_graph_passes
- Add numerics tests for both llama3 and deepseek_v3 with aot_fx_trace
  + transformer_block_bucketing
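The annotation-based module-stack lookup described above could look roughly like this. This is a hypothetical sketch of what get_module_stack_from_annotation might compute; the `module_fqn` key and the exact output shape are assumptions, not the PR's actual implementation.

```python
def get_module_stack_from_annotation(node_meta: dict) -> list[str]:
    """Recover an nn_module_stack-like list of FQN prefixes from a 'custom'
    annotation, outermost module first, so the bucketing pass can match a
    node to its enclosing transformer block."""
    fqn = node_meta.get("custom", {}).get("module_fqn", "")
    if not fqn:
        return []
    parts = fqn.split(".")
    return [".".join(parts[: i + 1]) for i in range(len(parts))]


meta = {"custom": {"module_fqn": "layers.3.mlp"}}
print(get_module_stack_from_annotation(meta))
# ['layers', 'layers.3', 'layers.3.mlp']
```

Expanding the FQN into its prefixes mimics what nn_module_stack would have provided had record_module_stack been enabled, which is why manual_overlap_bucketing can consume it through the module_stack_fn hook.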
@yiming0416 yiming0416 force-pushed the graph_trainer/transformer-block-bucketing-aot-fx-trace branch from aeadcae to d0591e7 Compare April 13, 2026 21:40

Labels

ciflow/8gpu, CLA Signed

3 participants