fix: resolve PT 2.11 DeviceMesh deprecation warnings and unify EP mesh #1684
Merged
hemildesai merged 12 commits into main on Apr 7, 2026
Comments from the PR author, in order:
- /claude review
- /ok to test 817e933
- /claude review
- /ok to test d61076c
- /claude review
- /claude review
- /ok to test 4206ca6
- /ok to test deb954c
- /ok to test 966405c
- /ok to test 2fb29da

Two PyTorch 2.11 deprecation warnings fired on every training run:

1. `_mesh_resources.get_root_mesh()` deprecated in favor of `DeviceMesh._get_root_mesh()`
2. `root_mesh["flattened_dim"]` deprecated for dims created via `_flatten()`

Additionally, the MoE mesh was created as a standalone `init_device_mesh` call separate from the main device mesh, requiring a redundant global collective and making TP+EP coexistence impossible.

Changes:
- Add `get_flat_mesh()` and `get_submesh()` utilities in mesh_utils.py that access `_flatten()` results directly via `_flatten_mapping`, and construct mixed-dim submeshes via `_unflatten()` from a parent flattened mesh
- Replace standalone `_create_moe_mesh()` with `_unflatten()` from the root mesh's non-pp dims, deriving EP process groups from the same mesh hierarchy
- EP mesh now spans dp, cp, and tp groups (matching the old standalone mesh semantics and enabling future TP+EP support)
- Consolidate `state_dict_utils.get_submesh` as a re-export of the shared `mesh_utils.get_submesh`
- Update all callers: base_recipe, parallelizer (FSDP2 + MoE), vlm/finetune, optim/utils, mesh.py axis size lookups

Validated with:
- 918 unit tests passing
- Qwen3 MoE 30B EP=8 (full + LoRA), LLaMA 3.1 8B PP=2 end-to-end training
- Multi-process verification that unified EP groups match standalone mesh
- Zero deprecation warnings across all runs

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

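To make the rank arithmetic behind "unflatten the non-pp dims into an EP mesh" concrete, here is a torch-free sketch. It assumes a row-major rank layout with pp as the outermost dim and an EP mesh shaped `(ep_shard, ep)`. The helper names `non_pp_ranks` and `unflatten_ranks`, the dim order, and the sizes are illustrative assumptions, not the PR's code, which goes through DeviceMesh's private `_unflatten()`.

```python
from math import prod


def non_pp_ranks(pp_stage, dp, cp, tp):
    """Ranks of one pipeline stage, assuming pp is the outermost mesh dim."""
    per_stage = dp * cp * tp
    start = pp_stage * per_stage
    return list(range(start, start + per_stage))


def unflatten_ranks(flat, sizes):
    """Split a flat row-major rank list into nested groups of the given sizes."""
    if prod(sizes) != len(flat):
        raise ValueError(f"sizes {sizes} do not cover {len(flat)} ranks")
    if len(sizes) == 1:
        return list(flat)
    step = len(flat) // sizes[0]
    return [
        unflatten_ranks(flat[i * step:(i + 1) * step], sizes[1:])
        for i in range(sizes[0])
    ]


# Hypothetical layout: pp=2, dp=2, cp=1, tp=2 (world size 8), EP=2.
ep_size = 2
for stage in range(2):
    flat = non_pp_ranks(stage, dp=2, cp=1, tp=2)
    # Inner lists are the EP groups; the outer index plays the role of ep_shard.
    ep_groups = unflatten_ranks(flat, (len(flat) // ep_size, ep_size))
    print(f"pp stage {stage}: EP groups {ep_groups}")
```

Because the EP groups are carved out of the same rank list that backs dp, cp, and tp, no second global mesh has to be built, which is the point the commit message makes about dropping the standalone `init_device_mesh` call.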
- Remove unused dp_cp_mesh assignment (ruff F841)
- Fix import ordering in parallelizer.py (ruff I001)
- Keep cross-component imports lazy to satisfy import-linter rules (moe -> distributed, optim -> distributed)
- Harden get_submesh size matching with try/except on _unflatten to handle ambiguous size collisions
- Consolidate state_dict_utils.get_submesh as thin wrapper delegating to mesh_utils.get_submesh

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

Address claude-bot review: size-only matching in get_submesh could pick the wrong parent flattened mesh if two entries have equal total size. After _unflatten, now validates that process groups for any root-mesh dim in the result match the root mesh's groups for that dim.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

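The size-collision problem is easy to reproduce on paper. The sketch below is a torch-free model under stated assumptions: a hypothetical `(dp=2, cp=2, tp=2)` mesh with row-major ranks, seen from rank 0, where both `cp_tp` and `dp_cp` flatten to 4 ranks. The helpers `split_rows`, `groups_containing`, and `pick_parent` are invented for illustration and compare plain rank lists rather than real process groups.

```python
def split_rows(flat, sizes):
    """Reshape a flat rank list into a row-major 2D grid of shape `sizes`."""
    rows, cols = sizes
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def groups_containing(grid, rank):
    """Return (dim-0 group, dim-1 group) of the grid that contain `rank`."""
    for row in grid:
        if rank in row:
            col = row.index(rank)
            return [r[col] for r in grid], list(row)
    raise ValueError(f"rank {rank} not in grid")


def pick_parent(flatten_mapping, names, sizes, root_groups, rank):
    """Choose the flattened parent whose unflattened groups match the root mesh."""
    for parent, flat in flatten_mapping.items():
        if len(flat) != sizes[0] * sizes[1]:
            continue  # total size mismatch, cannot be the parent
        dim0, dim1 = groups_containing(split_rows(flat, sizes), rank)
        # Size alone is ambiguous, so also require the per-dim groups to agree.
        if dim0 == root_groups[names[0]] and dim1 == root_groups[names[1]]:
            return parent
    raise KeyError(f"no flattened parent matches dims {names}")


# Rank 0's view: rank = dp * 4 + cp * 2 + tp.
root_groups = {"dp": [0, 4], "cp": [0, 2], "tp": [0, 1]}
flatten_mapping = {"cp_tp": [0, 1, 2, 3], "dp_cp": [0, 2, 4, 6]}  # both size 4

print(pick_parent(flatten_mapping, ("dp", "cp"), (2, 2), root_groups, rank=0))
```

A size-only match would have accepted `cp_tp` first, since it is listed first and also has 4 ranks; the group comparison rejects it because its dim-0 group `[0, 2]` is not the root mesh's dp group `[0, 4]`.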
…tions

- Revert state_dict_utils.get_submesh to inline impl (avoid cross-component import from moe -> distributed that breaks import-linter)
- Revert optim/utils.py to original approach (avoid optim -> distributed import)
- Remove stale lazy import in WanParallelizationStrategy
- Restore exact dp_mesh assertions in Wan strategy tests by monkeypatching get_submesh to return a known sentinel object

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

Replace fallback paths in get_flat_mesh and get_submesh that would silently trigger the PT 2.11 deprecation warning with explicit KeyError. If a dim is not found in mesh_dim_names or _flatten_mapping, it is a caller error rather than something to silently degrade on.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

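A minimal sketch of the fail-fast lookup described above, written against a stand-in object so it runs without torch. `StubMesh` and the string return values are hypothetical; only the attribute names `mesh_dim_names` and `_flatten_mapping` come from the commit message, and the real `get_flat_mesh` in mesh_utils.py may differ in its details.

```python
class StubMesh:
    """Minimal stand-in for a device mesh; this is not torch's DeviceMesh."""

    def __init__(self, dim_names, flattened):
        self.mesh_dim_names = tuple(dim_names)
        self._flatten_mapping = dict(flattened)

    def __getitem__(self, name):
        # A real mesh would return a sub-mesh; a tagged string is enough here.
        return f"submesh:{name}"


def get_flat_mesh(mesh, name):
    """Resolve a 1D mesh by name, raising instead of silently falling back."""
    if name in mesh.mesh_dim_names:
        return mesh[name]
    if name in mesh._flatten_mapping:
        return mesh._flatten_mapping[name]
    raise KeyError(
        f"{name!r} is neither a mesh dim {mesh.mesh_dim_names} "
        f"nor a flattened dim {tuple(mesh._flatten_mapping)}"
    )


mesh = StubMesh(("pp", "dp", "cp", "tp"), {"dp_cp": "flat:dp_cp"})
print(get_flat_mesh(mesh, "dp"))     # regular mesh dim
print(get_flat_mesh(mesh, "dp_cp"))  # flattened dim, read from the mapping
```

Raising on an unknown name surfaces typos at the call site instead of routing them through the deprecated `root_mesh["flattened_dim"]` path.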
…et_submesh

Test mocks for DeviceMesh now include mesh_dim_names, _flatten_mapping, and _get_root_mesh so that get_flat_mesh/get_submesh can resolve dims. Also patch dist.get_process_group_ranks in strategy integration tests for the get_submesh validation step.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

…ation

- Validate process groups for ALL requested dims (not just mesh dims) by using get_flat_mesh for both mesh and flattened dim lookups
- Fix 2-space indentation in test_strategy_integration.py

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

FakeWorldMesh.__getitem__ now handles both string "dp" and tuple
("dp",) lookups, since get_flat_mesh passes dim names as strings.
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
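
For readers without the test suite at hand, a test double that accepts both key spellings might look like the sketch below. The real `FakeWorldMesh` is not shown in this PR page, so the constructor, the `_submeshes` attribute, and the tuple return for multi-dim keys are all assumptions.

```python
class FakeWorldMesh:
    """Test double that accepts both "dp" and ("dp",) as lookup keys."""

    def __init__(self, submeshes):
        self._submeshes = dict(submeshes)
        self.mesh_dim_names = tuple(self._submeshes)

    def __getitem__(self, key):
        # Normalize a bare string to a 1-tuple so both spellings share one path.
        names = (key,) if isinstance(key, str) else tuple(key)
        if len(names) == 1:
            return self._submeshes[names[0]]
        return tuple(self._submeshes[n] for n in names)


dp_sentinel = object()
world = FakeWorldMesh({"dp": dp_sentinel, "tp": object()})
print(world["dp"] is world[("dp",)])
```

Returning a sentinel object lets a test assert identity on the looked-up mesh, regardless of whether the caller passed a string or a tuple.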
The docs finetune guide (added in #1678) uses :::{details} which requires the html_admonition myst extension. Without it, sphinx --fail-on-warning rejects the unknown directive.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

…design
The {details} directive doesn't exist in myst-parser. Replace with
{dropdown} from sphinx-design (already in extensions) which provides
the same collapsible UI. Also revert the unnecessary html_admonition
extension addition.
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

Comments from the PR author, in order:
- /ok to test 41bc246
- /claude review

adil-a approved these changes on Apr 6, 2026

Summary
- `_mesh_resources.get_root_mesh()` → use `_get_root_mesh()` / `_flatten_mapping` directly
- `root_mesh["flattened_dim"]` → access `_flatten()` results via `get_flat_mesh()` / `get_submesh()` utilities
- Replace standalone `_create_moe_mesh()` (`init_device_mesh`) with `_unflatten()` from the root mesh's non-pp dims, deriving EP process groups from the same mesh hierarchy

Design
Two utility functions in `mesh_utils.py`:
- `get_flat_mesh(device_mesh, name)`: reads `_flatten()` results from `_flatten_mapping` for 1D access
- `get_submesh(device_mesh, names)`: handles multi-dim tuples by finding a parent `_flatten()` result and calling `_unflatten()` to decompose it

Unified EP mesh creation:

Test plan
`init_device_mesh`

🤖 Generated with Claude Code