TP refactor for FSDP + TP integration #45028
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed fcea5ce to f98e208 (Compare)
- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests
Force-pushed 607cc11 to 739332c (Compare)
- Replace hook-based TP with DTensor-based TPStyle API
- TPStyle dataclass with dense kinds: colwise, rowwise, vocab
- apply_tensor_parallel() using PyTorch parallelize_module
- verify_tp_plan() for plan validation
- Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle
- DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3
- Extended DistributedConfig with tp/fsdp size and plan fields
- DistributedConfig serialization in configuration_utils
- MXFP4 NotImplementedError for DTensor TP
- Dense TP tests
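As a rough illustration of the declarative plan API this commit describes, here is a minimal sketch of what a `TPStyle` dataclass plus `verify_tp_plan()` could look like. The field names, the frozen dataclass, and the validation logic are assumptions for illustration, not the PR's actual implementation:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical sketch of a style-based plan entry; the real TPStyle in the
# PR may use different fields and validation.
@dataclass(frozen=True)
class TPStyle:
    kind: str                            # e.g. "colwise", "rowwise", "vocab"
    output_layout: Optional[str] = None  # e.g. "allgather"

_SUPPORTED_KINDS = {"colwise", "rowwise", "vocab"}

def verify_tp_plan(plan):
    # Reject plan entries whose kind is not a known dense style.
    for name, style in plan.items():
        if style.kind not in _SUPPORTED_KINDS:
            raise ValueError(f"Unsupported TP style {style.kind!r} for module {name!r}")
```

The appeal of this shape over hook-based TP is that a plan is plain data that can be validated up front, then handed to PyTorch's `parallelize_module` to do the actual sharding.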
Force-pushed 1aa7f5f to 11b55a2 (Compare)
Force-pushed dbc9619 to c567240 (Compare)
Force-pushed 34a5085 to eb428cc (Compare)
Force-pushed c567240 to c1dab9e (Compare)
Force-pushed eb428cc to e0c4e06 (Compare)
* MoE expert parallelism + sequence parallelism
  - Add PackedColwiseParallel for fused gate_up_proj weights
  - Add MoEExpertsParallel with per-expert DTensor sharding
  - Add PrepareModuleInputOutput for SP allgather/split hooks
  - Add _AllReduceBackward for MoE routing weight gradients
  - Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
  - _StridedShard handling in core_model_loading for interleaved weights
  - MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
  - DTensor rotary_pos_emb guard for mixtral
* Fix ruff linting and formatting
* Fix ruff formatting in core_model_loading.py
* Restore _IdentityOp accidentally removed in 25a1f48

  The _IdentityOp class (added by PR #44983) was accidentally deleted during the MoE expert parallelism work. It is needed by finegrained_fp8.py and metal_quantization.py as a pass-through reverse_op for dequantize operations.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Backport new TP/FSDP API + fix DTensor imports in Copied-from models
* from_pretrained orchestration + distributed save/load (#45409)
  * from_pretrained orchestration + save/load
    - Add gather_full_state_dict() for DTensor→full tensor saving
    - Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
    - Add _redistribute_dtensor() helper
    - Full distributed_config integration in from_pretrained/save_pretrained
    - Rename apply_fsdp2 → apply_fully_shard_data_parallel
    - save_optimizer() / load_optimizer() in distributed/utils
    - Trainer integration with distributed_config
    - Updated FSDP and TP tests for new orchestration API
    - DTensor shard-on-read test updates
  * revert distributed utils
  * eaaea
  * all tests for core modeling are passing
  * populate import from init for tp
  * ruff
  * ruff

  Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
```python
    return
    ...
# Filter out module-level comm hooks — they don't shard weights
_NON_WEIGHT_KINDS = {"activation", "module"}
```
Maybe separate TP and SP styles?
Restores modeling files to their base branch versions so the PR diff only shows the distributed/patches.py monkey-patch approach instead of noisy function moves in modeling files. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Convert string plan values ("colwise", "rowwise", etc.) to TPStyle
objects across 66+ model configs and modular files
- Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...)
with shard_plan
- Remove "replicated_with_grad_allreduce" entries (not needed for
DTensor TP)
- Migrate _tp_plan class attributes in modeling files from
"colwise_gather_output" to TPStyle("colwise", "allgather")
- Add TypeError in apply_tensor_parallel for unsupported plan values
- Remove old TensorParallelLayer tests (API removed in DTensor refactor)
- Regenerate auto-generated files via modular converter
```python
    return model
    ...
    if isinstance(fsdp_plan, str):
        if fsdp_plan == "auto":
```
Define fsdp_plan in every model and remove the auto code path (will be done in the next PR).
…m_head when embed_tokens is not in the plan.
ArthurZucker left a comment:
Okay, reviewed only core model loading.
```python
        model=model,
        missing_keys=loading_info.missing_keys if loading_info else None,
    )
    if len(collected_tensors) > 1 and model is not None:
```
:) this does not make sense to me
At this point, only params that need quantization have a self.quantization operation attached to them.
```python
    return _job
    ...
class DtensorShardOperation:
```
This should be in TP or sharding utils, not here.
```python
(b) One expert — source.ndim == param.ndim - 1
    MoE models stack experts along a leading axis (E, ...) in the
    model, but checkpoints store each expert in its own file
```
Not all checkpoints store them this way.
```python
(b) + (c) co-occurring
    MoE checkpoint that is both per-expert and pre-pack (e.g.
    `experts.2.w1.weight`). Resolve the expert axis first (b); the
    generic loop then handles the remaining dims with (c) behavior.
"""
```
Okay, this is a great start, but:
We want each case to be defined by a different string / by checking the operations that are going to be applied to the layer, not by a very general, fit-all approach!
```python
2) Multi-interval on one dim — read each piece, concat on that dim:
   source shape = [8, 4]
   intervals = [[(0, 2), (4, 6)], [(0, 4)]]
   → cat([source[0:2, 0:4], source[4:6, 0:4]], dim=0)
```
I really don't think we need to cat....

```python
out = torch.empty((4, 4), dtype=source.dtype, device=source.device)
out[0:2] = source[0:2, 0:4]
out[2:4] = source[4:6, 0:4]
```

Worst case, `source[[0, 1, 4, 5], 0:4]` does the same as well...
```python
    else:
        expected_shape = ref.shape
    ...
    # When a WeightConverter produces the full global tensor, slice it to the local DTensor shard.
```
Mmm, I don't understand; how can that happen?
```python
    return renamed_key, source_pattern
    ...
def concretize_target_patterns(
```
remind me why we are touching this?
```python
    empty_param = meta_model_state_dict[renamed_key]
    try:
        empty_param = model.get_parameter_or_buffer(renamed_key)
    except (AttributeError, KeyError):
        if getattr(model, "_is_fsdp_managed_module", False):
            raise RuntimeError(
                f"FSDP shard-on-read requires the live parameter for {renamed_key!r}, "
                f"but get_parameter_or_buffer() failed."
            )
```
Anything FSDP-specific should avoid taking up space here / being inlined directly.
```python
def _job():
    return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype)
...
def _strided_intervals(
```
For strided, the most efficient is probably:
- have a staging buffer, copy the slice into it
- copy twice (since the rows are in cache)
We'll potentially check this with @McPatate as well, but let's do what we can on our side!
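One way to picture the strided case: rank r owns interleaved blocks r, r + world_size, r + 2*world_size, ... along the sharded dim, and each owned block becomes one contiguous (start, stop) read. A hypothetical sketch of that interval math (a guess at what `_strided_intervals` computes, not the PR's actual code):

```python
def strided_intervals(dim_size, block_size, rank, world_size):
    # Interleaved-block ownership: block b (of block_size rows) belongs to
    # rank b % world_size; emit the intervals this rank should read.
    if dim_size % (block_size * world_size) != 0:
        raise ValueError("dim_size must divide evenly into blocks across ranks")
    n_blocks = dim_size // block_size
    return [(b * block_size, (b + 1) * block_size) for b in range(rank, n_blocks, world_size)]
```

With many small blocks this yields many small reads, which is why copying them through a contiguous staging buffer (as suggested above) can beat issuing each read directly against the checkpoint.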
```python
        pieces_read.append(source[tuple(piece_slices)])
    return torch.cat(pieces_read, dim=multi_interval_dim).to(device=device, dtype=dtype)
    ...
def _owns_expert(self, expert_idx: int) -> bool:
```
It does not make sense to have this here; it's too specific.
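For context, expert ownership under expert parallelism usually reduces to simple index arithmetic. A hedged sketch assuming a contiguous per-rank block layout (the real `_owns_expert` may use a different layout, e.g. round-robin):

```python
def owns_expert(expert_idx, ep_rank, ep_size, num_experts):
    # Contiguous-block layout: rank r owns experts [r * per_rank, (r + 1) * per_rank).
    per_rank = num_experts // ep_size
    return per_rank * ep_rank <= expert_idx < per_rank * (ep_rank + 1)
```

Because it is pure arithmetic over the EP mesh, it arguably belongs next to the other sharding utilities rather than in the loading path, which is the reviewer's point.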
|
[For maintainers] Suggested jobs to run (before merge):
run-slow: afmoe, apertus, arcee, aria, audioflamingo3, bamba, bitnet
|
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45028&sha=5b336b |
- Verify loading
- Training (verify_all_loss -> training with saving + loading back for generate?)