Skip to content

fix: resolve PT 2.11 DeviceMesh deprecation warnings and unify EP mesh#1684

Merged
hemildesai merged 12 commits intomainfrom
hemild/fix-devicemesh-deprecation-unified-ep
Apr 7, 2026
Merged

fix: resolve PT 2.11 DeviceMesh deprecation warnings and unify EP mesh#1684
hemildesai merged 12 commits intomainfrom
hemild/fix-devicemesh-deprecation-unified-ep

Conversation

@hemildesai
Copy link
Copy Markdown
Contributor

Summary

  • Fix two PyTorch 2.11 deprecation warnings that fire on every training run:
    • _mesh_resources.get_root_mesh() → use _get_root_mesh() / _flatten_mapping directly
    • root_mesh["flattened_dim"] → access _flatten() results via get_flat_mesh() / get_submesh() utilities
  • Unify MoE mesh with main device mesh: replace standalone _create_moe_mesh() (init_device_mesh) with _unflatten() from the root mesh's non-pp dims, deriving EP process groups from the same mesh hierarchy
  • EP mesh now spans dp, cp, and tp groups, matching the old standalone mesh semantics and enabling future TP+EP support

Design

Two utility functions in mesh_utils.py:

  • get_flat_mesh(device_mesh, name) — reads _flatten() results from _flatten_mapping for 1D access
  • get_submesh(device_mesh, names) — handles multi-dim tuples by finding a parent _flatten() result and calling _unflatten() to decompose it

Unified EP mesh creation:

# Old: standalone init_device_mesh (separate global collective, required tp=1)
moe_mesh = init_device_mesh("cuda", (pp, ep_shard, ep), ...)

# New: derived from root mesh via _flatten() + _unflatten()
non_pp_mesh = device_mesh[("dp_replicate", "dp_shard", "cp", "tp")]._flatten()
moe_mesh = non_pp_mesh._unflatten(0, (ep_shard_size, ep_size), ("ep_shard", "ep"))

Test plan

  • 918 unit tests passing (distributed, moe, optim)
  • Qwen3-30B-A3B MoE EP=8 full finetune — training converging, 0 warnings
  • Qwen3-30B-A3B MoE EP=8 LoRA finetune — training converging, 0 warnings
  • LLaMA 3.1 8B PP=2 finetune — training converging, 0 warnings
  • Multi-process verification: unified EP groups identical to standalone init_device_mesh
  • Verified EP spans TP groups (enables future TP+EP)

🤖 Generated with Claude Code

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hemildesai
Copy link
Copy Markdown
Contributor Author

/claude review

@hemildesai
Copy link
Copy Markdown
Contributor Author

/ok to test 817e933

Comment thread nemo_automodel/components/distributed/mesh_utils.py Outdated
Comment thread nemo_automodel/components/distributed/mesh_utils.py
@hemildesai
Copy link
Copy Markdown
Contributor Author

/claude review

@hemildesai
Copy link
Copy Markdown
Contributor Author

/ok to test d61076c

Comment thread nemo_automodel/components/distributed/mesh_utils.py
Comment thread tests/unit_tests/distributed/test_parallelization_strategies.py Outdated
@hemildesai
Copy link
Copy Markdown
Contributor Author

/claude review

Comment thread nemo_automodel/components/distributed/mesh_utils.py
@hemildesai
Copy link
Copy Markdown
Contributor Author

/claude review

@hemildesai
Copy link
Copy Markdown
Contributor Author

/ok to test 4206ca6

@hemildesai
Copy link
Copy Markdown
Contributor Author

/ok to test deb954c

@hemildesai
Copy link
Copy Markdown
Contributor Author

/ok to test 966405c

@hemildesai
Copy link
Copy Markdown
Contributor Author

/ok to test 2fb29da

hemildesai and others added 12 commits April 5, 2026 20:48
Two PyTorch 2.11 deprecation warnings fired on every training run:
1. `_mesh_resources.get_root_mesh()` deprecated in favor of `DeviceMesh._get_root_mesh()`
2. `root_mesh["flattened_dim"]` deprecated for dims created via `_flatten()`

Additionally, the MoE mesh was created as a standalone `init_device_mesh` call
separate from the main device mesh, requiring a redundant global collective and
making TP+EP coexistence impossible.

Changes:
- Add `get_flat_mesh()` and `get_submesh()` utilities in mesh_utils.py that
  access `_flatten()` results directly via `_flatten_mapping`, and construct
  mixed-dim submeshes via `_unflatten()` from a parent flattened mesh
- Replace standalone `_create_moe_mesh()` with `_unflatten()` from the root
  mesh's non-pp dims, deriving EP process groups from the same mesh hierarchy
- EP mesh now spans dp, cp, and tp groups (matching the old standalone mesh
  semantics and enabling future TP+EP support)
- Consolidate `state_dict_utils.get_submesh` as a re-export of the shared
  `mesh_utils.get_submesh`
- Update all callers: base_recipe, parallelizer (FSDP2 + MoE), vlm/finetune,
  optim/utils, mesh.py axis size lookups

Validated with:
- 918 unit tests passing
- Qwen3 MoE 30B EP=8 (full + LoRA), LLaMA 3.1 8B PP=2 end-to-end training
- Multi-process verification that unified EP groups match standalone mesh
- Zero deprecation warnings across all runs

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
- Remove unused dp_cp_mesh assignment (ruff F841)
- Fix import ordering in parallelizer.py (ruff I001)
- Keep cross-component imports lazy to satisfy import-linter rules
  (moe -> distributed, optim -> distributed)
- Harden get_submesh size matching with try/except on _unflatten
  to handle ambiguous size collisions
- Consolidate state_dict_utils.get_submesh as thin wrapper delegating
  to mesh_utils.get_submesh

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Address claude-bot review: size-only matching in get_submesh could pick
the wrong parent flattened mesh if two entries have equal total size.
After _unflatten, now validates that process groups for any root-mesh dim
in the result match the root mesh's groups for that dim.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
…tions

- Revert state_dict_utils.get_submesh to inline impl (avoid cross-component
  import from moe -> distributed that breaks import-linter)
- Revert optim/utils.py to original approach (avoid optim -> distributed import)
- Remove stale lazy import in WanParallelizationStrategy
- Restore exact dp_mesh assertions in Wan strategy tests by monkeypatching
  get_submesh to return a known sentinel object

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Replace fallback paths in get_flat_mesh and get_submesh that would
silently trigger the PT 2.11 deprecation warning with explicit KeyError.
If a dim is not found in mesh_dim_names or _flatten_mapping, it is a
caller error rather than something to silently degrade on.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
…et_submesh

Test mocks for DeviceMesh now include mesh_dim_names, _flatten_mapping,
and _get_root_mesh so that get_flat_mesh/get_submesh can resolve dims.
Also patch dist.get_process_group_ranks in strategy integration tests
for the get_submesh validation step.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
…ation

- Validate process groups for ALL requested dims (not just mesh dims)
  by using get_flat_mesh for both mesh and flattened dim lookups
- Fix 2-space indentation in test_strategy_integration.py

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
FakeWorldMesh.__getitem__ now handles both string "dp" and tuple
("dp",) lookups, since get_flat_mesh passes dim names as strings.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
The docs finetune guide (added in #1678) uses :::{details} which
requires the html_admonition myst extension. Without it, sphinx
--fail-on-warning rejects the unknown directive.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
…design

The {details} directive doesn't exist in myst-parser. Replace with
{dropdown} from sphinx-design (already in extensions) which provides
the same collapsible UI. Also revert the unnecessary html_admonition
extension addition.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
@hemildesai
Copy link
Copy Markdown
Contributor Author

/ok to test 41bc246

@hemildesai
Copy link
Copy Markdown
Contributor Author

/claude review

Copy link
Copy Markdown
Contributor

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants