
Remove dtensor dependency in Tensor Parallel #43157

Merged

3outeille merged 17 commits into v5-test_tensor_parallel_moe from v5-tensor-parallel-dtensor on Jan 13, 2026

Conversation

@3outeille (Member) commented Jan 7, 2026

Motivation

  • The API stays the same but the implementation details are different. It will be easier for us to debug which distributed calls are made, since we differentiate them ourselves (see VocabParallelEmbedding for an example) and can thus set a breakpoint.

  • Benchmarks show a speedup with torch.compile both on and off

  • An example with ColwiseParallel:

import torch
import torch.distributed as dist
import torch.nn as nn

# TensorParallelLayer, all_gather, get_tensor_shard and distribute_module are
# helper utilities defined in the PR's tensor-parallel module.


class _AllReduceBackward(torch.autograd.Function):
    """Identity forward, all-reduce backward. Used before colwise layers (f in Megatron)."""

    @staticmethod
    def forward(ctx, x, device_mesh):
        ctx.device_mesh = device_mesh
        return x

    @staticmethod
    def backward(ctx, grad_output):
        device_mesh = ctx.device_mesh
        if device_mesh.size() == 1:
            return grad_output, None
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
        return grad_output, None


def all_reduce_backward(x, device_mesh):
    """Identity forward, all-reduce backward. Use before colwise layers."""
    return _AllReduceBackward.apply(x, device_mesh)


class ColwiseParallel(TensorParallelLayer):
    """
    Column-wise parallel: weight is sharded on dim -2 (output features).
    Forward: input replicated -> output sharded on last dim.
    If gather_output=True, output is all-gathered to produce full tensor.
    """

    def __init__(self, gather_output: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.gather_output = gather_output

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        input_tensor = inputs[0] if inputs else inputs
        return all_reduce_backward(input_tensor, device_mesh)

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        if self.gather_output:
            return all_gather(outputs, device_mesh)
        return outputs

    def shard_tensor(
        self,
        param,
        param_type=None,
        param_casting_dtype=None,
        to_contiguous=None,
        rank=None,
        device_mesh=None,
        tensor_idx=None,
    ):
        device_mesh = self.device_mesh
        empty_param = self.empty_param
        rank = self.rank
        if param_type == "bias":
            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
        else:
            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
        parameter = parameter.to(param_casting_dtype)
        return parameter, None

    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
        distribute_module(
            module,
            device_mesh,
            self._prepare_input_fn,
            self._prepare_output_fn,
        )
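
The snippet above implements only the f operator. For context, here is a minimal sketch of its conjugate g (all-reduce forward, identity backward), which a rowwise layer would apply to its partial output; the class name is illustrative, not necessarily what the PR uses:

class _AllReduceForward(torch.autograd.Function):
    """All-reduce forward, identity backward (g in Megatron). Used after rowwise layers."""

    @staticmethod
    def forward(ctx, x, device_mesh):
        if device_mesh.size() == 1:
            return x
        # Each rank holds a partial sum of the output; reduce so all ranks see the full result.
        dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # The upstream gradient is already replicated across ranks; pass it through.
        return grad_output, None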

Benchmark

  • Scripts
#!/usr/bin/env python
"""
Profiling script for Tensor Parallel using torch.profiler.

Run on BOTH branches:

    # WITH DTensor (base branch)
    git checkout v5-test_tensor_parallel_moe
    torchrun --nproc_per_node=2 profile_tp_trace.py

    # WITHOUT DTensor (PR branch)
    git checkout v5-tensor-parallel-dtensor
    torchrun --nproc_per_node=2 profile_tp_trace.py

    # With torch.compile enabled
    torchrun --nproc_per_node=2 profile_tp_trace.py --compile

Open traces with Perfetto UI: https://ui.perfetto.dev/
"""

import argparse
import os
import subprocess

import torch
import torch.distributed as dist
from torch.profiler import ProfilerActivity, profile
from transformers import AutoModelForCausalLM


def get_rank():
    if dist.is_initialized():
        return dist.get_rank()
    return 0


def log_rank0(msg):
    if get_rank() == 0:
        print(msg)


def get_current_branch():
    """Get current git branch name."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
        )
        return result.stdout.strip()
    except Exception:
        return "unknown"


def parse_args():
    parser = argparse.ArgumentParser(description="Profile Tensor Parallel with torch.profiler")
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable torch.compile for the model",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    
    model_id = "meta-llama/Llama-3.2-3B"
    batch_size = 4
    seq_length = 512
    output_dir = "./profiler_traces"
    
    torch.manual_seed(42)
    
    branch = get_current_branch()
    
    log_rank0("=" * 60)
    log_rank0("Tensor Parallel Profiling with torch.profiler")
    log_rank0("=" * 60)
    log_rank0(f"Branch: {branch}")
    log_rank0(f"Model: {model_id}")
    log_rank0(f"torch.compile: {args.compile}")
    
    # Load model with tensor parallel
    log_rank0("\nLoading model with tensor parallel...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch.float16,
        tp_plan="auto",
    )
    
    # Apply torch.compile if requested
    if args.compile:
        log_rank0("Compiling model with torch.compile...")
        model = torch.compile(model)
    
    dist.barrier()
    device = model.device
    rank = get_rank()
    log_rank0(f"Model loaded on device: {device}")
    
    # Create input tensors
    vocab_size = model.config.vocab_size
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_length), device=device)
    labels = torch.randint(0, vocab_size, (batch_size, seq_length), device=device)
    
    model.train()
    
    # Warmup
    log_rank0("\nRunning warmup...")
    for _ in range(3):
        model.zero_grad()
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
    torch.cuda.synchronize()
    dist.barrier()
    
    # Profile one forward + backward pass
    log_rank0("\nProfiling forward + backward pass...")
    
    os.makedirs(output_dir, exist_ok=True)
    
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
    ) as prof:
        model.zero_grad()
        
        # Forward pass
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss
        
        # Backward pass
        loss.backward()
        
        torch.cuda.synchronize()
    
    dist.barrier()
    
    # Export trace with branch name and compile status
    compile_suffix = "_compiled" if args.compile else ""
    trace_file = os.path.join(output_dir, f"trace_{branch}{compile_suffix}_rank{rank}.json")
    prof.export_chrome_trace(trace_file)
    log_rank0(f"\nTrace saved to: {trace_file}")
    
    # Print summary table
    if rank == 0:
        compile_status = " (compiled)" if args.compile else ""
        print("\n" + "=" * 60)
        print(f"PROFILER SUMMARY - {branch}{compile_status}")
        print("=" * 60)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
    
    log_rank0("\nOpen trace with Perfetto UI: https://ui.perfetto.dev/")


if __name__ == "__main__":
    main()
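
A hedged sketch (not part of the PR) of how the exported traces could be compared: load each Chrome-trace JSON and sum the duration of kernels whose names contain "nccl". The file names below assume the default branch names from the docstring; adjust them to the actual trace paths.

import json

def nccl_time_us(trace_path):
    """Total duration (microseconds) of trace events whose name mentions NCCL."""
    with open(trace_path) as f:
        events = json.load(f)["traceEvents"]
    return sum(e.get("dur", 0) for e in events if "nccl" in e.get("name", "").lower())

for path in [
    "profiler_traces/trace_v5-test_tensor_parallel_moe_rank0.json",
    "profiler_traces/trace_v5-tensor-parallel-dtensor_rank0.json",
]:
    print(f"{path}: {nccl_time_us(path) / 1e3:.1f} ms in NCCL kernels")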
  • Baseline (profiler trace screenshot)
  • No Dtensor (profiler trace screenshot)
  • Baseline + torch.compile (profiler trace screenshot)
  • No Dtensor + torch.compile (profiler trace screenshot)

@3outeille marked this pull request as ready for review January 9, 2026 08:32
@3outeille changed the base branch from main to v5-test_tensor_parallel_moe January 9, 2026 08:33
@github-actions Bot requested review from MekkCyber and SunMarc January 9, 2026 08:33
@3outeille force-pushed the v5-tensor-parallel-dtensor branch from 5694039 to b135ba0 on January 9, 2026 10:42
@3outeille (Member Author) commented:

run-slow: afmoe, apertus, arcee, aria, bamba, cohere, cohere2, cwm, dbrx, deepseek_v2, deepseek_v3, diffllama, doge, dots1, emu3, ernie4_5

github-actions Bot commented Jan 9, 2026

This comment contains run-slow, running the specified jobs:

models: ["models/afmoe", "models/apertus", "models/arcee", "models/aria", "models/bamba", "models/cohere", "models/cohere2", "models/cwm", "models/dbrx", "models/deepseek_v2", "models/deepseek_v3", "models/diffllama", "models/doge", "models/dots1", "models/emu3", "models/ernie4_5"]
quantizations: []

github-actions Bot commented Jan 9, 2026

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

…arameter handling in `set_param_for_module` and updated tensor sharding functions. Removed deprecated code and added new utility functions for block size calculations.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment:

Will review in detail later on; make sure you write explicitly the motivation behind removing dtensor please (in terms of API etc.)!


# Remove from missing keys (it's either mismatched, or all good)
missing_keys.discard(target_name)
# Skip shape check when tensor parallel sharding is applied (shape is intentionally different)
A Collaborator commented on the diff above:

this is not really optimal, we should build / copy the utils or just create a dummy Dtensor with meta device just to let it handle the shape! But we cannot ship without shape checking IMO!
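
A minimal sketch of the "build / copy the utils" option: compute the expected local shard shape directly (torch.chunk semantics, which is what DTensor's Shard placement uses) and compare it against the loaded parameter instead of skipping the check. The helper name and the usage line are illustrative, not the PR's actual implementation:

def expected_local_shape(global_shape, shard_dim, rank, world_size):
    """Shape rank `rank` should hold when `shard_dim` is chunked across `world_size` ranks."""
    shape = list(global_shape)
    n = shape[shard_dim]
    chunk = -(-n // world_size)  # ceil division, matching torch.chunk
    shape[shard_dim] = min(chunk, max(0, n - rank * chunk))
    return tuple(shape)

# Hypothetical check at load time, instead of skipping shape checking entirely:
# assert param.shape == expected_local_shape(empty_param.shape, -2, rank, world_size)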

@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: afmoe, apertus, arcee, aria, bamba, cohere, cohere2, cwm, dbrx, deepseek_v2, deepseek_v3, diffllama, doge, dots1, emu3, ernie4_5

@3outeille (Member Author) commented:

run-slow: afmoe, apertus, arcee, aria, bamba, cohere, cohere2, cwm, dbrx, deepseek_v2, deepseek_v3, diffllama, doge, dots1, emu3, ernie4_5

@github-actions

This comment contains run-slow, running the specified jobs:

models: ["models/afmoe", "models/apertus", "models/arcee", "models/aria", "models/bamba", "models/cohere", "models/cohere2", "models/cwm", "models/dbrx", "models/deepseek_v2", "models/deepseek_v3", "models/diffllama", "models/doge", "models/dots1", "models/emu3", "models/ernie4_5"]
quantizations: []

@github-actions

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@3outeille (Member Author) commented Jan 13, 2026:

Merging this PR into #42809 to unblock me. It will provide more thorough testing.

@3outeille merged commit b0c5a98 into v5-test_tensor_parallel_moe on Jan 13, 2026
27 checks passed
@3outeille deleted the v5-tensor-parallel-dtensor branch January 13, 2026 08:08
@3outeille restored the v5-tensor-parallel-dtensor branch January 13, 2026 08:08
@ArthurZucker

Ok!

3outeille added a commit that referenced this pull request Jan 30, 2026
* begin Moe test tensor parallel

* create tiny moe model + fix test tensor parallel Moe

fix tensor parallel MoE test

* fix backward pass test in tensor parallel for Dense model (#42811)

* fix

* linting

* use mixtral instead for testing

* fix dtensor and tensor mismatch

* linting

* checkout test tensor parallel to be like main

* avoid hack and create class instead

* fix loading ep

* add moe test

* now EP inference works again but pass still fails

* linting

* now load from checkpoint. Creating a nn.Parameter for param_value will not transfer its attribute (especially _is_hf_initialized)

* forward now works (add LocalPackedColwise + dont use EP router)

* for now test in float32

* don't do all_reduce manually for GatherParallel. Convert to dtensor approach

* Remove dtensor dependency in Tensor Parallel (#43157)

* dense test is passing

* Refactor tensor parallel implementation by removing unused partition_tensor methods

* keep removing dependencies on Dtensor

* rename test file

* Update tensor parallel plans to use "colwise_gather_output" across multiple models

* Remove unused "gather" references and update tensor parallel plans to "colwise_gather_output" in multiple model configurations.

* Refactor tensor parallel plans in Fbgemm and FineGrained quantizers by removing unused configurations and comments related to "gather" operations.

* add 'split_input' option in RowwiseParallel + replace rowwise_replicate with 'rowwise_split_input'

* Add PackedColwiseParallel and PackedRowwiseParallel + Update configuration plans

* mixing files and some fixes for tp and tp_plan

* clean tensor parallel API

* linting

* linting

* Refactor core model loading and tensor parallel utilities. Improved parameter handling in `set_param_for_module` and updated tensor sharding functions. Removed deprecated code and added new utility functions for block size calculations.

* code quality

* make fixup

* tp works for dense and moe in float32 only

* fix merge conflicts that broke TP

* revert parsing for tp plan

* all reduce after experts

* compile compatible dist ops

* fix gate_up_proj gradient test by doing a split that takes into account that it is fused + all_reduce to get the full gradient before functional.linear

* fix moe backward fp32

* remove functional.Linear to use nn.Linear in experts (this way we attach hooks)

* moe work with tied embedding as well

* style

* all tests pass

* make fix-up

* typo

* use transformer seed + pytest parametrized

* Moved weight and bias dim mapping to ParallelInterface

* simplified shard_tensor signature

* sync shard_tensor logic with the one in origin/main

* add function check to avoid mismatch check during set_param_for_module

* remove disable. I was on an older torch version

* Add pytest skip condition for tensor parallel tests requiring PyTorch >= 2.9

* linting

* linting

* fixing remaining modular

* linting

* Refactor get_expected_sharded_shape to be only one call

* Remove redundant prepare_module_tp method from TensorParallelLayer subclasses

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <arthur.zucker@gmail.com>