[graph_trainer] Add CooR precompile support for DeepSeek V3 #2916
Open
bobrenjc93 wants to merge 3 commits into gh/bobrenjc93/39/base from
Conversation
bobrenjc93 added a commit that referenced this pull request on Apr 9, 2026
Extend the CooR precompile infrastructure to support MoE models
with expert parallelism (EP).
Key changes:
1. Pass DeviceMesh instead of ProcessGroup to collectives
(expert_parallel.py):
- Change all_to_all_single/all_to_all_single_autograd calls to
pass `device_mesh` directly instead of `device_mesh.get_group()`.
This lets PyTorch's CooR-aware `_resolve_group` dispatch path
derive ProcessGroups from DeviceMesh inputs via the
`mesh_get_process_group` custom op, making the graph
rank-agnostic without needing custom PG pickling.
2. MoE guard error handling (precompile_main.py):
- Wrap the forward pass in try/except since MoE expert routing
produces dynamic token counts via all_to_all_single, whose
Inductor shape guards fail on fake tensors.
- Only suppress the error when the on_compile callback ran
successfully (compile_succeeded flag), not just when the
artifact file exists on disk.
3. ReordererSequenceParallel CooR fix (expert_parallel.py):
- Use `_sym_get_coordinate(0)` instead of `get_local_rank()` so
the rank coordinate is a SymInt under compile_on_one_rank.
Without this, get_local_rank() returns a concrete int that gets
baked into the FX graph, causing CooR precompile to use rank 0's
token slice indices on all ranks.
4. Extract build_forward_extra_kwargs() (forward_utils.py):
- Factor the extra_kwargs construction (positions, attention_masks)
out of Trainer.post_dataloading_process() into a standalone
function in torchtitan/components/forward_utils.py.
- Trainer, Validator, and precompile_main all call this shared
function, eliminating duplicated config-inspection logic.
- Precompile uses a _DummyTokenizer(eos_id=0) since it only needs
structurally-matching masks, not real values.
- Resolves the TODO in validate.py about deduplicating with
Trainer.post_dataloading_process.
5. Test + docs (run_precompile_tests.py, README.md):
- Add DSv3 precompile test definitions with EP=4, ETP=1.
- Add DSv3 precompile example commands to README.
- Clarify TODO comments: DSv3 cudagraph limitation affects both
precompile and plain AOT (MoE routing CPU tensor ops).
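The DeviceMesh-vs-ProcessGroup change in item 1 can be illustrated with a minimal single-process sketch. All names below are toy stand-ins, not PyTorch's actual `_resolve_group` implementation: `ToyMesh`/`ToyGroup` play the roles of DeviceMesh/ProcessGroup, and `ToyMesh.get_group()` stands in for the `mesh_get_process_group` custom op. The point is that the group lookup happens inside the op at runtime, so nothing rank-specific is baked into the traced call.

```python
# Toy sketch of a _resolve_group-style dispatcher that accepts either a
# concrete group or a mesh (names invented for illustration).
class ToyGroup:
    def __init__(self, ranks):
        self.ranks = ranks

class ToyMesh:
    def __init__(self, ranks):
        self._ranks = ranks

    def get_group(self):
        # Resolved at runtime, per rank -- analogous to the
        # mesh_get_process_group custom op.
        return ToyGroup(self._ranks)

def resolve_group(group_or_mesh):
    # Accept either form; derive the group from the mesh when given a mesh.
    if isinstance(group_or_mesh, ToyMesh):
        return group_or_mesh.get_group()
    return group_or_mesh  # already a concrete group

mesh = ToyMesh([0, 1, 2, 3])
print(resolve_group(mesh).ranks)  # [0, 1, 2, 3]
```

Passing the mesh rather than `mesh.get_group()` at the call site moves the group derivation behind the op boundary, which is what makes the serialized graph replayable on any rank without pickling a ProcessGroup.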
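The guard-suppression rule in item 2 can be sketched as follows (function and flag names are invented for illustration; the real logic lives in precompile_main.py): a shape-guard failure is swallowed only when the on_compile callback actually ran, never merely because an artifact file exists on disk.

```python
def run_forward_with_guard_tolerance(forward, compile_succeeded):
    """Illustrative sketch: tolerate guard failures only after a
    successful compile callback (compile_succeeded flag)."""
    try:
        return forward()
    except RuntimeError as exc:
        # The artifact was already produced; fake-tensor shape guards from
        # MoE's dynamic token counts may still trip on replay.
        if compile_succeeded and "guard" in str(exc).lower():
            return None
        raise  # not a tolerated failure -- surface it
```

Keying on the flag rather than the artifact's existence avoids masking a real failure just because a stale artifact from an earlier run is sitting on disk.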
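The rank-baking bug fixed in item 3 is easy to demonstrate in plain Python (toy functions, not the real ReordererSequenceParallel code): a rank captured as a concrete value at trace time reproduces rank 0's token slice everywhere, whereas a rank kept as a runtime input (the role the SymInt plays under compile_on_one_rank) slices correctly per rank.

```python
def baked_slice(xs, tokens_per_rank, traced_rank):
    # traced_rank models a concrete int captured at trace time:
    # whatever value it had on the compiling rank is frozen into the graph.
    start = traced_rank * tokens_per_rank
    return xs[start:start + tokens_per_rank]

def symbolic_slice(xs, tokens_per_rank, rank):
    # rank stays a runtime input, so each rank computes its own slice.
    return xs[rank * tokens_per_rank:(rank + 1) * tokens_per_rank]

tokens = list(range(8))
# Graph traced on rank 0 but replayed on rank 1:
print(baked_slice(tokens, 2, 0))     # -> [0, 1]  (rank 0's slice, wrong for rank 1)
print(symbolic_slice(tokens, 2, 1))  # -> [2, 3]  (correct per-rank slice)
```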
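A hedged sketch of the factored-out helper from item 4 (the signature and flag names here are invented; the real function lives in torchtitan/components/forward_utils.py): one shared place builds the positions/attention_masks extra_kwargs from config, so Trainer, Validator, and precompile_main no longer each duplicate the config inspection.

```python
def build_forward_extra_kwargs(seq_len, needs_positions, needs_masks):
    """Illustrative stand-in for the shared helper: construct the
    extra_kwargs dict (positions, attention_masks) in one place."""
    extra_kwargs = {}
    if needs_positions:
        extra_kwargs["positions"] = list(range(seq_len))
    if needs_masks:
        # Precompile only needs a structurally-matching mask, not real
        # values (hence the _DummyTokenizer(eos_id=0) in the PR).
        extra_kwargs["attention_masks"] = [[1] * seq_len]
    return extra_kwargs
```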
Note: Also requires a PyTorch fix to include `torchbind_constants`
(which contain the ProcessGroup) in `CompiledFxGraphConstants.unwrap()`
so they are injected into the loaded Inductor module during
deserialization.
E2E validation (1 node, 8x H100, --debug.deterministic --debug.seed=42):
DSv3 AOT precompile, DP=4 TP=2 EP=4 ETP=1, 5 steps:
step 1: loss=7.94874 grad_norm=3.7722
step 2: loss=6.06074 grad_norm=4.1036
step 3: loss=5.01055 grad_norm=3.2918
step 4: loss=4.91436 grad_norm=2.7785
step 5: loss=4.67803 grad_norm=2.7366
Result: bitwise identical baseline vs precompile
MFU: ~0.90% baseline vs ~0.89% precompile (equivalent)
MAST validation (CooR precompile vs AOT baseline, DSv3 debugmodel,
DP=4 TP=2 EP=4 ETP=1, 10 steps): https://pxl.cl/9jvMS
ghstack-source-id: 6e0d7fc
Pull-Request: #2916
bobrenjc93 added a commit that referenced this pull request on Apr 12, 2026
ghstack-source-id: 5aaa127
Pull-Request: #2916
bobrenjc93 added a commit that referenced this pull request on Apr 12, 2026
ghstack-source-id: b5d3593
Pull-Request: #2916
Stack from ghstack (oldest at bottom):
Extend the CooR precompile infrastructure to support MoE models
with expert parallelism (EP).
Key changes:
1. Pass DeviceMesh instead of ProcessGroup to collectives
(expert_parallel.py):
- Change all_to_all_single/all_to_all_single_autograd calls to
pass `device_mesh` directly instead of `device_mesh.get_group()`.
This lets PyTorch's CooR-aware `_resolve_group` dispatch path
derive ProcessGroups from DeviceMesh inputs via the
`mesh_get_process_group` custom op, making the graph
rank-agnostic without needing custom PG pickling.
2. MoE guard error handling (precompile_main.py):
- Wrap the forward pass in try/except since MoE expert routing
produces dynamic token counts via all_to_all_single, whose
Inductor shape guards fail on fake tensors.
- Only suppress the error when the on_compile callback ran
successfully (compile_succeeded flag), not just when the
artifact file exists on disk.
3. ReordererSequenceParallel CooR fix (expert_parallel.py):
- Use `_sym_get_coordinate(0)` instead of `get_local_rank()` so
the rank coordinate is a SymInt under compile_on_one_rank.
Without this, get_local_rank() returns a concrete int that gets
baked into the FX graph, causing CooR precompile to use rank 0's
token slice indices on all ranks.
4. Extract build_forward_extra_kwargs() (forward_utils.py):
- Factor the extra_kwargs construction (positions, attention_masks)
out of Trainer.post_dataloading_process() into a standalone
function in torchtitan/components/forward_utils.py.
- Trainer, Validator, and precompile_main all call this shared
function, eliminating duplicated config-inspection logic.
- Precompile uses a _DummyTokenizer(eos_id=0) since it only needs
structurally-matching masks, not real values.
- Resolves the TODO in validate.py about deduplicating with
Trainer.post_dataloading_process.
5. Test + docs (run_precompile_tests.py, README.md):
- Add DSv3 precompile test definitions with EP=4, ETP=1.
- Add DSv3 precompile example commands to README.
- Clarify TODO comments: DSv3 cudagraph limitation affects both
precompile and plain AOT (MoE routing CPU tensor ops).
Note: Also requires a PyTorch fix to include `torchbind_constants`
(which contain the ProcessGroup) in `CompiledFxGraphConstants.unwrap()`
so they are injected into the loaded Inductor module during
deserialization.
E2E validation (1 node, 8x H100, --debug.deterministic --debug.seed=42):
DSv3 AOT precompile, DP=4 TP=2 EP=4 ETP=1, 5 steps:
step 1: loss=7.94874 grad_norm=3.7722
step 2: loss=6.06074 grad_norm=4.1036
step 3: loss=5.01055 grad_norm=3.2918
step 4: loss=4.91436 grad_norm=2.7785
step 5: loss=4.67803 grad_norm=2.7366
Result: bitwise identical baseline vs precompile
MFU: ~0.90% baseline vs ~0.89% precompile (equivalent)
MAST validation (CooR precompile vs AOT baseline, DSv3 debugmodel,
DP=4 TP=2 EP=4 ETP=1, 10 steps): https://pxl.cl/9jvMS