[GraphTrainer] Support transformer_block_bucketing in aot_fx_trace mode #2934
Draft
yiming0416 wants to merge 1 commit into main
e7a8559 to 3d07026
    custom = fwd_node.meta.get("custom")
    if custom:
        node.meta.setdefault("custom", {}).update(custom)
    node.meta.setdefault("custom", {})[_IS_BWD] = True
Contributor
We won't need to do this once @tugsbayasgalan lands the change in upstream PyTorch.
The backward graph will then have a standardized metadata tag.
        return result


    def annotate_module_fqns(model: nn.Module) -> None:
Contributor
cc @yushangdi, who is also adding an annotation for the module path.
Thanks! I can probably just reuse this.
63e130f to 89b5762
5138509 to aeadcae
aot_fx_trace mode traces with record_module_stack=False, so nodes lack nn_module_stack metadata that manual_overlap_bucketing uses to match nodes to modules. This change uses the existing annotate_fn mechanism (already used for AC region tagging) to tag module FQNs during tracing, then provides a custom module_stack_fn to manual_overlap_bucketing.

Changes:
- Add annotate_module_fqns() and get_module_stack_from_annotation() to common_utils.py
- Call annotate_module_fqns() in both llama3 and deepseek_v3 annotate functions
- Add module_stack_fn parameter to transformer_block_bucketing_reordering_pass
- Update construct_default_graph_passes to accept compile_config, model, and parallel_dims, and wire up the bucketing pass with reassign_to_pg
- Pass config/model/parallel_dims from GraphTrainer to construct_default_graph_passes
- Add numerics tests for both llama3 and deepseek_v3 with aot_fx_trace + transformer_block_bucketing
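The idea of tagging module FQNs and reading them back as a module stack can be illustrated with a small, torch-free sketch. Everything in it is an assumption for illustration: `FakeNode` stands in for an FX node, `module_fqn` is a hypothetical key inside `meta["custom"]`, and the helper names are made up. The actual `annotate_module_fqns()` and `get_module_stack_from_annotation()` in common_utils.py may look quite different.

```python
from dataclasses import dataclass, field


@dataclass
class FakeNode:
    """Stand-in for an FX node: only a name and a meta dict."""
    name: str
    meta: dict = field(default_factory=dict)


MODULE_FQN = "module_fqn"  # hypothetical key inside meta["custom"]


def tag_node(node: FakeNode, fqn: str) -> None:
    """Record the owning module's FQN on the node, as an annotate_fn might."""
    node.meta.setdefault("custom", {})[MODULE_FQN] = fqn


def module_stack_from_annotation(node: FakeNode) -> list[str]:
    """Expand a tagged FQN into the chain of enclosing module paths.

    "layers.0.attention" -> ["layers", "layers.0", "layers.0.attention"]
    Untagged nodes yield an empty list.
    """
    fqn = node.meta.get("custom", {}).get(MODULE_FQN)
    if not fqn:
        return []
    parts = fqn.split(".")
    return [".".join(parts[: i + 1]) for i in range(len(parts))]


def nodes_in_module(nodes: list[FakeNode], prefix: str) -> list[str]:
    """Names of nodes whose reconstructed module stack contains `prefix`."""
    return [n.name for n in nodes if prefix in module_stack_from_annotation(n)]


nodes = [FakeNode("all_gather_0"), FakeNode("mm_0"), FakeNode("all_gather_1")]
tag_node(nodes[0], "layers.0.attention")
tag_node(nodes[1], "layers.0.feed_forward")
tag_node(nodes[2], "layers.1.attention")

# Match nodes to a transformer block without any nn_module_stack metadata.
layer0 = nodes_in_module(nodes, "layers.0")
```

Here `layer0` collects the two nodes tagged under `layers.0`, which is the kind of node-to-module matching a bucketing pass needs when `nn_module_stack` is absent.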
aeadcae to d0591e7
Summary
Enable the transformer_block_bucketing pass in aot_fx_trace compile mode. Previously this only worked in aot and jit modes because aot_fx_trace traces a joint fwd+bwd graph where nn_module_stack metadata alone isn't sufficient to correctly identify and separate forward vs backward nodes for bucketing.

Root cause: In the joint graph produced by aot_fx_trace, forward and backward collectives coexist. The existing manual_overlap_bucketing pass needs to bucket each direction independently (forward collectives in forward execution order, backward collectives in reverse order) but had no way to distinguish them.

Solution: Leverage the existing annotate_fn mechanism (already used for AC region tagging) to:

1. Tag module FQNs during tracing (annotate_module_fqns)
2. Mark backward nodes with is_bwd=True during metadata propagation
3. Provide a custom module_stack_fn to manual_overlap_bucketing that filters by fwd/bwd

A new joint_transformer_block_bucketing_reordering_pass applies bucketing in two passes, once for forward nodes and once for backward, so each direction's collectives are bucketed and reordered independently.
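The two-pass idea can be sketched without torch. This is a rough illustration under stated assumptions, not the PR's implementation: nodes are plain dicts, `is_bwd` and `module_fqn` are hypothetical keys under `meta["custom"]`, and "bucketing" is reduced to grouping collective names per transformer block.

```python
from collections import defaultdict

IS_BWD = "is_bwd"  # hypothetical key, mirroring the is_bwd tag described above


def make_node(name, block, bwd=False):
    """Build a dict standing in for an FX node with annotation metadata."""
    custom = {"module_fqn": block}
    if bwd:
        custom[IS_BWD] = True
    return {"name": name, "meta": {"custom": custom}}


def bucket_one_direction(nodes, blocks, want_bwd):
    """Group collectives per block, considering only one direction."""
    buckets = defaultdict(list)
    for node in nodes:
        custom = node["meta"].get("custom", {})
        if bool(custom.get(IS_BWD)) != want_bwd:
            continue  # the filter hides the other direction's nodes
        buckets[custom.get("module_fqn")].append(node["name"])
    # Forward walks blocks in execution order; backward walks them reversed.
    order = list(reversed(blocks)) if want_bwd else list(blocks)
    return [(b, buckets[b]) for b in order if buckets[b]]


def joint_bucketing(nodes, blocks):
    """Run the bucketing twice: once for forward, once for backward."""
    return (
        bucket_one_direction(nodes, blocks, want_bwd=False),
        bucket_one_direction(nodes, blocks, want_bwd=True),
    )


blocks = ["layers.0", "layers.1"]
joint_graph = [
    make_node("ag_fwd_0", "layers.0"),
    make_node("ag_fwd_1", "layers.1"),
    make_node("rs_bwd_1", "layers.1", bwd=True),
    make_node("rs_bwd_0", "layers.0", bwd=True),
]
fwd_buckets, bwd_buckets = joint_bucketing(joint_graph, blocks)
```

Filtering by direction before grouping keeps a block's forward and backward collectives out of the same bucket, even though both carry the same module FQN in the joint graph.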