
[graph_trainer] Add CooR precompile support for DeepSeek V3 #2916

Open
bobrenjc93 wants to merge 3 commits into gh/bobrenjc93/39/base from gh/bobrenjc93/39/head

Conversation


bobrenjc93 (Contributor) commented on Apr 9, 2026

Stack from ghstack (oldest at bottom):

Extend the CooR precompile infrastructure to support MoE models
with expert parallelism (EP).

Key changes:

1. Pass DeviceMesh instead of ProcessGroup to collectives
   (`expert_parallel.py`):
   - Change `all_to_all_single`/`all_to_all_single_autograd` calls to
     pass `device_mesh` directly instead of `device_mesh.get_group()`.
     This lets PyTorch's CooR-aware `_resolve_group` dispatch path
     derive ProcessGroups from DeviceMesh inputs via the
     `mesh_get_process_group` custom op, making the graph
     rank-agnostic without needing custom PG pickling.

2. MoE guard error handling (`precompile_main.py`):
   - Wrap the forward pass in try/except, since MoE expert routing
     produces dynamic token counts via `all_to_all_single`, whose
     Inductor shape guards fail on fake tensors.
   - Suppress the error only when the `on_compile` callback ran
     successfully (`compile_succeeded` flag), not merely when the
     artifact file exists on disk.

3. ReordererSequenceParallel CooR fix (`expert_parallel.py`):
   - Use `_sym_get_coordinate(0)` instead of `get_local_rank()` so
     the rank coordinate is a SymInt under compile_on_one_rank.
     Without this, `get_local_rank()` returns a concrete int that gets
     baked into the FX graph, causing CooR precompile to use rank 0's
     token slice indices on all ranks.

4. Extract `build_forward_extra_kwargs()` (`forward_utils.py`):
   - Factor the `extra_kwargs` construction (positions, attention_masks)
     out of `Trainer.post_dataloading_process()` into a standalone
     function in `torchtitan/components/forward_utils.py`.
   - Trainer, Validator, and precompile_main all call this shared
     function, eliminating duplicated config-inspection logic.
   - Precompile uses a `_DummyTokenizer(eos_id=0)`, since it only needs
     structurally matching masks, not real values.
   - Resolves the TODO in `validate.py` about deduplicating with
     `Trainer.post_dataloading_process`.

5. Tests + docs (`run_precompile_tests.py`, `README.md`):
   - Add DSv3 precompile test definitions with EP=4, ETP=1.
   - Add DSv3 precompile example commands to the README.
   - Clarify TODO comments: the DSv3 cudagraph limitation affects both
     precompile and plain AOT (MoE routing does CPU tensor ops).
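As a plain-Python illustration of the dispatch change in item 1, the sketch below mimics how a collective can accept either a process group or a device mesh and resolve the group internally. `FakeProcessGroup` and `FakeDeviceMesh` are illustrative stand-ins, not PyTorch's real types; the real work happens in PyTorch's `_resolve_group` path and the `mesh_get_process_group` custom op.

```python
class FakeProcessGroup:
    """Stand-in for torch.distributed.ProcessGroup."""


class FakeDeviceMesh:
    """Stand-in for torch.distributed.device_mesh.DeviceMesh."""

    def __init__(self, pg):
        self._pg = pg

    def get_group(self):
        return self._pg


def resolve_group(group_or_mesh):
    # Resolving the mesh to its ProcessGroup *inside* the collective,
    # rather than at the call site, is what lets a traced graph carry the
    # mesh (rank-agnostic) instead of baking in a concrete group object.
    if isinstance(group_or_mesh, FakeDeviceMesh):
        return group_or_mesh.get_group()
    return group_or_mesh


def all_to_all_single_sketch(tensor, group_or_mesh):
    group = resolve_group(group_or_mesh)  # accepts either kind of input
    # A real collective would shuffle `tensor` across `group` here.
    return tensor, group
```

The call-site change in the PR is then just `all_to_all_single(..., device_mesh)` instead of `all_to_all_single(..., device_mesh.get_group())`.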

Note: This also requires a PyTorch fix to include `torchbind_constants`
(which contain the ProcessGroup) in `CompiledFxGraphConstants.unwrap()`,
so that they are injected into the loaded Inductor module during
deserialization.
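The error-suppression rule described in item 2 can be sketched in plain Python. `GuardError`, `run_guarded_forward`, and the callback wiring below are illustrative stand-ins, not the PR's actual code; only the `compile_succeeded`-before-suppressing idea comes from the description above.

```python
class GuardError(RuntimeError):
    """Stand-in for the Inductor shape-guard failure raised on fake tensors."""


def run_guarded_forward(forward, on_compile):
    """Run `forward`, suppressing a guard failure only if compilation ran."""
    compile_succeeded = False

    def _on_compile():
        nonlocal compile_succeeded
        on_compile()
        compile_succeeded = True  # set only after the callback finishes

    try:
        return forward(_on_compile)
    except GuardError:
        # Suppress only when on_compile actually succeeded; checking for an
        # artifact file on disk alone could mask a stale artifact from an
        # earlier run whose compile step never ran this time.
        if not compile_succeeded:
            raise
        return None
```

With MoE routing, the guard failure after a successful compile is expected, so it is swallowed; any failure before the callback runs still propagates.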

E2E validation (1 node, 8x H100, `--debug.deterministic --debug.seed=42`):

  DSv3 AOT precompile, DP=4 TP=2 EP=4 ETP=1, 5 steps:
    step 1: loss=7.94874 grad_norm=3.7722
    step 2: loss=6.06074 grad_norm=4.1036
    step 3: loss=5.01055 grad_norm=3.2918
    step 4: loss=4.91436 grad_norm=2.7785
    step 5: loss=4.67803 grad_norm=2.7366
  Result: bitwise identical, baseline vs. precompile
  MFU: ~0.90% baseline vs. ~0.89% precompile (equivalent)
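The rank-baking pitfall behind item 3 can be mimicked without torch: a concrete rank read at trace time is captured as a constant in the resulting function, while a symbolic rank stays a free input. The helper names below are hypothetical illustrations, not torchtitan code.

```python
def make_baked_slicer(trace_time_rank, tokens_per_rank):
    # get_local_rank()-style: the compiling rank's concrete int is captured
    # here, so every rank that replays this "graph" gets the same slice.
    start = trace_time_rank * tokens_per_rank
    return lambda tokens: tokens[start:start + tokens_per_rank]


def make_symbolic_slicer(tokens_per_rank):
    # _sym_get_coordinate(0)-style: the rank remains an input of the traced
    # function, so each rank computes its own slice indices at run time.
    return lambda tokens, rank: tokens[
        rank * tokens_per_rank:(rank + 1) * tokens_per_rank
    ]
```

Under compile_on_one_rank, only the second form lets the graph compiled on rank 0 select the right token slice on every other rank.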

MAST validation (CooR precompile vs AOT baseline, DSv3 debugmodel,
DP=4 TP=2 EP=4 ETP=1, 10 steps): https://pxl.cl/9jvMS

[ghstack-poisoned]
bobrenjc93 added a commit that referenced this pull request Apr 9, 2026
ghstack-source-id: 6e0d7fc
Pull-Request: #2916
meta-cla bot added the CLA Signed label on Apr 9, 2026
bobrenjc93 added a commit that referenced this pull request Apr 9, 2026
ghstack-source-id: 6e0d7fc
Pull-Request: #2916
@bobrenjc93 bobrenjc93 marked this pull request as ready for review April 9, 2026 20:18
@bobrenjc93 bobrenjc93 requested review from aorenste and zhxchen17 April 9, 2026 20:20
bobrenjc93 added a commit that referenced this pull request Apr 12, 2026
ghstack-source-id: 5aaa127
Pull-Request: #2916
bobrenjc93 added a commit that referenced this pull request Apr 12, 2026
ghstack-source-id: b5d3593
Pull-Request: #2916
@bobrenjc93 bobrenjc93 marked this pull request as draft April 12, 2026 16:35
@bobrenjc93 bobrenjc93 marked this pull request as ready for review April 12, 2026 16:39

Labels: ciflow/8gpu, CLA Signed