Remove dtensor dependency in Tensor Parallel #43157
3outeille merged 17 commits into v5-test_tensor_parallel_moe
Conversation
… "colwise_gather_output" in multiple model configurations.
…y removing unused configurations and comments related to "gather" operations.
…te 'rowwise_split_input'
5694039 to b135ba0
run-slow: afmoe, apertus, arcee, aria, bamba, cohere, cohere2, cwm, dbrx, deepseek_v2, deepseek_v3, diffllama, doge, dots1, emu3, ernie4_5

This comment contains models: ["models/afmoe", "models/apertus", "models/arcee", "models/aria", "models/bamba", "models/cohere", "models/cohere2", "models/cwm", "models/dbrx", "models/deepseek_v2", "models/deepseek_v3", "models/diffllama", "models/doge", "models/dots1", "models/emu3", "models/ernie4_5"]

CI Results: ✅ No failing test specific to this PR 🎉!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker left a comment
Will review in detail later on. Please make sure you explicitly write up the motivation behind removing dtensor (in terms of API, etc.)!
```python
# Remove from missing keys (it's either mismatched, or all good)
missing_keys.discard(target_name)
# Skip shape check when tensor parallel sharding is applied (shape is intentionally different)
```
This is not really optimal; we should build/copy the utils, or just create a dummy DTensor on the meta device to let it handle the shape! But we cannot ship without shape checking IMO!
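As a rough illustration of the kind of check being discussed (not the PR's actual code), a standalone helper could compute the expected local shard shape from the full shape and the tensor-parallel world size, so the mismatch check does not have to be skipped. The helper name and the even-split assumption below are illustrative only:

```python
# Hypothetical sketch: compute the expected local shape of a tensor-parallel shard
# so the usual shape check can still run on sharded parameters.
def expected_shard_shape(full_shape: tuple[int, ...], shard_dim: int, world_size: int) -> tuple[int, ...]:
    size = full_shape[shard_dim]
    # assumes an even split across ranks, like a Shard(dim) placement would produce
    if size % world_size != 0:
        raise ValueError(f"dim {shard_dim} of size {size} is not divisible by world_size={world_size}")
    shape = list(full_shape)
    shape[shard_dim] = size // world_size
    return tuple(shape)

# e.g. a (4096, 11008) weight column-sharded on dim 0 over 4 ranks -> (1024, 11008) per rank
assert expected_shard_shape((4096, 11008), shard_dim=0, world_size=4) == (1024, 11008)
```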
[For maintainers] Suggested jobs to run (before merge): run-slow: afmoe, apertus, arcee, aria, bamba, cohere, cohere2, cwm, dbrx, deepseek_v2, deepseek_v3, diffllama, doge, dots1, emu3, ernie4_5
run-slow: afmoe, apertus, arcee, aria, bamba, cohere, cohere2, cwm, dbrx, deepseek_v2, deepseek_v3, diffllama, doge, dots1, emu3, ernie4_5

This comment contains models: ["models/afmoe", "models/apertus", "models/arcee", "models/aria", "models/bamba", "models/cohere", "models/cohere2", "models/cwm", "models/dbrx", "models/deepseek_v2", "models/deepseek_v3", "models/diffllama", "models/doge", "models/dots1", "models/emu3", "models/ernie4_5"]

CI Results: ✅ No failing test specific to this PR 🎉!
Merging this PR into #42809 to unblock me. It will provide more thorough testing.

Ok!
* begin Moe test tensor parallel
* create tiny moe model + fix test tensor parallel Moe
* create tiny moe model + fix test tensor parallel Moe; fix tensor parallel MoE test
* fix backward pass test in tensor parallel for Dense model (#42811)
* fix
* linting
* use mixtral instead for testing
* fix dtensor and tensor mismatch
* linting
* checkout test tensor parallel to be like main
* avoid hack and create class instead
* fix loading ep
* add moe test
* now EP inference works again but pass still fails
* linting
* now load from checkpoint. Creating a nn.Parameter for param_value will not transfer its attributes (especially _is_hf_initialized)
* forward now works (add LocalPackedColwise + don't use EP router)
* for now test in float32
* don't do all_reduce manually for GatherParallel. Convert to dtensor approach
* Remove dtensor dependency in Tensor Parallel (#43157)
* dense test is passing
* Refactor tensor parallel implementation by removing unused partition_tensor methods
* keep removing dependencies on DTensor
* rename test file
* Update tensor parallel plans to use "colwise_gather_output" across multiple models
* Remove unused "gather" references and update tensor parallel plans to "colwise_gather_output" in multiple model configurations
* Refactor tensor parallel plans in Fbgemm and FineGrained quantizers by removing unused configurations and comments related to "gather" operations
* add 'split_input' option in RowwiseParallel + replace rowwise_replicate with 'rowwise_split_input'
* Add PackedColwiseParallel and PackedRowwiseParallel + update configuration plans
* mixing files and some fixes for tp and tp_plan
* clean tensor parallel api
* linting
* linting
* Refactor core model loading and tensor parallel utilities. Improved parameter handling in `set_param_for_module` and updated tensor sharding functions. Removed deprecated code and added new utility functions for block size calculations.
* code quality
* make fixup
* tp works for dense and moe in float32 only
* fix merge conflicts that broke TP
* revert parsing for tp plan
* all reduce after experts
* compile compatible dist ops
* fix gate_up_proj gradient test by doing splitting that takes into account that it is fused + all_reduce to get full gradient before functional.linear
* fix moe backward fp32
* remove functional.Linear to use nn.Linear in experts (this way we attach hooks)
* moe works with tied embedding as well
* style
* all tests pass
* make fix-up
* typo
* use transformers seed + pytest parametrize
* Moved weight and bias dim mapping to ParallelInterface
* simplified shard_tensor signature
* sync shard_tensor logic with the one in origin/main
* add function check to avoid mismatch check during set_param_for_module
* remove disable. I was on an older torch version
* Add pytest skip condition for tensor parallel tests requiring PyTorch >= 2.9
* linting
* linting
* fixing remaining modular
* linting
* Refactor get_expected_sharded_shape to be only one call
* Remove redundant prepare_module_tp method from TensorParallelLayer subclasses

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <arthur.zucker@gmail.com>
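For context on the renamed plan entries mentioned in the commits above ("colwise_gather_output", "rowwise_split_input"), a per-module tensor-parallel plan might look roughly like the sketch below. The module paths and the assignment of plans to modules are illustrative assumptions, not the exact plans shipped in this PR:

```python
# Rough sketch of tp_plan entries using the plan names from the commit list above;
# which module gets which plan here is an illustrative assumption.
tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    # column-shard the weight, then gather so callers see the full (unsharded) output
    "lm_head": "colwise_gather_output",
    # replaces the old "rowwise_replicate": the row-parallel layer now splits its own input
    "model.layers.*.mlp.down_proj": "rowwise_split_input",
}
```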
Motivation
The API stays the same, but the implementation details are different. It will be easier for us to debug which distributed calls are made, since we differentiate them ourselves (for example `VocabParallelEmbedding`) and can thus put a breakpoint.
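As a rough sketch of that debugging point: when the embedding issues its own collective instead of going through DTensor dispatch, the call site is explicit and easy to step into. The class below is a simplified, forward-only assumption of how a vocab-parallel embedding can work, not the PR's actual implementation:

```python
import torch
import torch.distributed as dist
import torch.nn as nn

class VocabParallelEmbedding(nn.Module):
    """Simplified sketch: each rank holds a slice of the vocabulary, and rows owned by
    other ranks are filled in with an explicit all_reduce (easy to breakpoint)."""

    def __init__(self, num_embeddings: int, embedding_dim: int, rank: int, world_size: int):
        super().__init__()
        shard = num_embeddings // world_size  # assumes an even vocab split
        self.vocab_start = rank * shard
        self.vocab_end = self.vocab_start + shard
        self.weight = nn.Parameter(torch.empty(shard, embedding_dim))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        mask = (input_ids < self.vocab_start) | (input_ids >= self.vocab_end)
        local_ids = (input_ids - self.vocab_start).masked_fill(mask, 0)
        out = nn.functional.embedding(local_ids, self.weight)
        out = out.masked_fill(mask.unsqueeze(-1), 0.0)
        # the collective is issued right here, so a breakpoint lands exactly on it
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
        return out
```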
Benchmarks show a speedup with `torch.compile` on and off. An example with `ColwiseParallel`:

(benchmark table comparing `torch.compile` on vs. off not captured in this extract)
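A minimal sketch of how such an on/off comparison might be measured (the shapes, iteration counts, and single-shard setup are arbitrary assumptions, not the script behind the PR's numbers):

```python
# Hypothetical micro-benchmark: eager vs. torch.compile for one column-parallel shard.
import time
import torch

def bench(fn, warmup: int = 5, iters: int = 50) -> float:
    # simple wall-clock timing with warmup; synchronize when running on GPU
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
# one colwise shard of a 4096 -> 4096 projection on a hypothetical 2-way mesh
layer = torch.nn.Linear(4096, 2048, bias=False, device=device)
x = torch.randn(8, 512, 4096, device=device)

eager_t = bench(lambda: layer(x))
compiled = torch.compile(layer)
compiled_t = bench(lambda: compiled(x))
print(f"eager {eager_t * 1e3:.2f} ms vs compiled {compiled_t * 1e3:.2f} ms")
```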