Skip to content

TP refactor for FSDP + TP integration#45028

Draft
3outeille wants to merge 52 commits into fsdp-vs-ddp from
refactor-tp-dtensor
Draft

TP refactor for FSDP + TP integration#45028
3outeille wants to merge 52 commits into fsdp-vs-ddp from
refactor-tp-dtensor

Conversation

@3outeille
Copy link
Copy Markdown
Member

@3outeille 3outeille commented Mar 26, 2026

  • TODO
    • how will dtensor works with quantization ?
    • how will dtensor works with kernels ?
    • Needs end to end test (combine verify_all_loss -> training with saving + loading back for generate ?)
    • double check Save FSDP + TP
      • do test that does FSDP + TP -> training -> saving -> loading back -> training (compare to a normal training end to end)

Verify loading

# python verify_loading.py --mode single_gpu
# torchrun --nproc_per_node=4 verify_loading.py --mode fsdp
# torchrun --nproc_per_node=4 verify_loading.py --mode tp
# torchrun --nproc_per_node=4 verify_loading.py --mode tp_fsdp
#
# Sanity-check script: load Qwen3-0.6B under several distribution strategies
# and print the forward-pass loss on one fixed sentence, so the values can be
# compared across modes (single GPU, FSDP, TP, TP+FSDP).
import argparse, os, torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.distributed import DistributedConfig

parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["single_gpu", "fsdp", "tp", "tp_fsdp"], required=True)
args = parser.parse_args()

if args.mode != "single_gpu":
    torch.distributed.init_process_group(backend="nccl")
    rank = int(os.environ["RANK"])
    # Use LOCAL_RANK for device selection AND tensor placement: on multi-node
    # runs the global RANK exceeds the per-node GPU count, so indexing CUDA
    # devices by RANK (as the original code did for the inputs below) breaks.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    world_size = torch.distributed.get_world_size()
else:
    rank = 0
    local_rank = 0
    torch.cuda.set_device(0)
    world_size = 1  # single GPU

configs = {
    "single_gpu": None,
    "fsdp": DistributedConfig(fsdp_size=world_size, fsdp_plan="auto"),
    "tp": DistributedConfig(tp_size=world_size, tp_plan="auto", enable_sequence_parallel=True),
    # tp_size is fixed at 2, so tp_fsdp assumes world_size is an even multiple of 2.
    "tp_fsdp": DistributedConfig(tp_size=2, tp_plan="auto", fsdp_size=world_size // 2, fsdp_plan="auto", enable_sequence_parallel=True),
}

print(configs[args.mode])
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", distributed_config=configs[args.mode], dtype=torch.float32)
if args.mode == "single_gpu":
    model = model.to("cuda:0")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
text = "The capital of France is Paris and the capital of Germany is Berlin and the capital of Italy is Rome"
# Inputs must live on this process's local GPU (matches set_device above).
inputs = tokenizer(text, return_tensors="pt").to(f"cuda:{local_rank}")

model.eval()
with torch.no_grad():
    loss = model(**inputs, labels=inputs["input_ids"].clone()).loss

# Only rank 0 prints, to avoid duplicated output under torchrun.
if rank == 0:
    print(f"{args.mode}: loss = {loss.item():.4f}")

if args.mode != "single_gpu":
    torch.distributed.destroy_process_group()

Training

# torchrun --nproc_per_node=4 train_fsdp_tp.py

import os

import torch
import torch.distributed.checkpoint as dcp
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.distributed import DistributedConfig
from torchtitan.distributed import utils as dist_utils #TODO(3outeille): add this to transformers.distributed

def build_packed_dataset(dataset_name, tokenizer, seq_len, dp_rank, dp_world_size):
    """Stream + tokenize + greedy-pack documents into fixed-length (input, label) windows.

    Documents are tokenized and appended to a running buffer; whenever the
    buffer holds at least ``seq_len + 1`` tokens, one window is emitted where
    labels are the inputs shifted right by one token. The stream is sharded
    across data-parallel ranks so each rank sees a disjoint slice.
    """
    stream = load_dataset(dataset_name, name="en", split="train", streaming=True)
    stream = stream.shard(num_shards=dp_world_size, index=dp_rank)
    token_buffer = []
    window = seq_len + 1  # one extra token so labels can be shifted by one

    def pack(batch):
        # Mutates token_buffer across calls: leftover tokens from one batch
        # are carried into the next so no tokens are dropped between batches.
        for document in batch["text"]:
            token_buffer.extend(tokenizer(document)["input_ids"])
        packed_inputs, packed_labels = [], []
        while len(token_buffer) >= window:
            packed_inputs.append(token_buffer[:seq_len])
            packed_labels.append(token_buffer[1:window])
            del token_buffer[:window]
        return {"input_ids": packed_inputs, "labels": packed_labels}

    stream = stream.map(pack, batched=True, remove_columns=stream.column_names)
    return stream.with_format("torch")


if __name__ == "__main__":
    # Train Qwen3-0.6B for a few steps under TP + FSDP, then save the model in
    # HF format and the (sharded) optimizer state via torch DCP.

    model_name = "Qwen/Qwen3-0.6B"
    dataset_name = "allenai/c4"
    seq_len = 512
    num_steps, lr = 50, 3e-4
    batch_size = 4
    save_dir = "./checkpoints"

    torch.distributed.init_process_group(backend="nccl")
    rank, local_rank = int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # 2-way TP x 2-way FSDP on 4 GPUs; alternatives below for pure FSDP / TP.
    distributed_config = DistributedConfig(tp_size=2, tp_plan="auto", fsdp_size=2, fsdp_plan="auto", enable_sequence_parallel=True)
    # distributed_config = DistributedConfig(fsdp_size=4, fsdp_plan="auto")
    # distributed_config = DistributedConfig(tp_size=4, tp_plan="auto", enable_sequence_parallel=True)

    # `dtype=` replaces the deprecated `torch_dtype=` (the run log printed
    # "`torch_dtype` is deprecated! Use `dtype` instead!" once per rank).
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        distributed_config=distributed_config,
        dtype=torch.bfloat16,
    )

    # Data-parallel sharding follows the "fsdp" mesh dim when present;
    # otherwise every rank reads the full stream (pure-TP case).
    has_fsdp_dim = "fsdp" in model.device_mesh.mesh_dim_names
    dp_rank = model.device_mesh["fsdp"].get_local_rank() if has_fsdp_dim else 0
    dp_world_size = model.device_mesh["fsdp"].size() if has_fsdp_dim else 1

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dataset = build_packed_dataset(dataset_name, tokenizer, seq_len, dp_rank, dp_world_size)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    data_iterator = iter(dataloader)
    for step in range(num_steps):
        try:
            batch = next(data_iterator)
        except StopIteration:
            # Streamed shard exhausted before num_steps: restart the epoch
            # rather than crashing mid-run.
            data_iterator = iter(dataloader)
            batch = next(data_iterator)
        input_ids = batch["input_ids"].to(f"cuda:{local_rank}")
        labels = batch["labels"].to(f"cuda:{local_rank}")

        loss = model(input_ids, labels=labels).loss
        loss.backward()
        # DTensor-aware grad clipping; returns the total grad norm.
        grad_norm = dist_utils.clip_grad_norm_(list(model.parameters()), max_norm=1.0, foreach=True)
        optimizer.step()
        optimizer.zero_grad()

        if rank == 0:
            print(f"Step {step:>4d} | Loss: {loss.item():.4f} | Grad norm: {grad_norm.item():.4f}")

    # Save model (HF format) and optimizer (DCP — optimizer state stays sharded).
    model.save_pretrained(save_dir)
    dcp.save({"optimizer": optimizer.state_dict()}, checkpoint_id=os.path.join(save_dir, "optimizer"))

    if rank == 0:
        print(f"Saved to {save_dir}")

    torch.distributed.destroy_process_group()
(env_refactor-tp-dtensor) ➜  refactor-tp-dtensor git:(refactor-tp-dtensor) ✗ torchrun --nproc_per_node=4 train_fsdp_tp.py 2>&1 | tee log.txt
W0404 15:36:13.989000 408584 torch/distributed/run.py:862] 
W0404 15:36:13.989000 408584 torch/distributed/run.py:862] *****************************************
W0404 15:36:13.989000 408584 torch/distributed/run.py:862] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0404 15:36:13.989000 408584 torch/distributed/run.py:862] *****************************************
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|██████████| 311/311 [00:00<00:00, 1332.48it/s]
Loading weights: 100%|██████████| 311/311 [00:00<00:00, 1393.62it/s]
Loading weights: 100%|██████████| 311/311 [00:00<00:00, 1506.23it/s]
Loading weights: 100%|██████████| 311/311 [00:00<00:00, 1343.11it/s]
[rank0]:W0404 15:36:38.769000 408764 torch/distributed/tensor/_redistribute.py:371] While redistributing from (_NormPartial(2.0), _NormPartial(2.0)) to (Replicate(), Replicate()), 2 sequential all_reduce operations will be performed. This is suboptimal: multiple collective operations have higher latency (separate kernel launches and synchronization points) and may give inconsistent results between ranks due to different reduction orders. To optimize, flatten mesh dimensions ["fsdp", "tp"] so DTensor can use a single operation instead.
[rank1]:W0404 15:36:38.773000 408765 torch/distributed/tensor/_redistribute.py:371] While redistributing from (_NormPartial(2.0), _NormPartial(2.0)) to (Replicate(), Replicate()), 2 sequential all_reduce operations will be performed. This is suboptimal: multiple collective operations have higher latency (separate kernel launches and synchronization points) and may give inconsistent results between ranks due to different reduction orders. To optimize, flatten mesh dimensions ["fsdp", "tp"] so DTensor can use a single operation instead.
[rank3]:W0404 15:36:38.775000 408767 torch/distributed/tensor/_redistribute.py:371] While redistributing from (_NormPartial(2.0), _NormPartial(2.0)) to (Replicate(), Replicate()), 2 sequential all_reduce operations will be performed. This is suboptimal: multiple collective operations have higher latency (separate kernel launches and synchronization points) and may give inconsistent results between ranks due to different reduction orders. To optimize, flatten mesh dimensions ["fsdp", "tp"] so DTensor can use a single operation instead.
[rank2]:W0404 15:36:38.775000 408766 torch/distributed/tensor/_redistribute.py:371] While redistributing from (_NormPartial(2.0), _NormPartial(2.0)) to (Replicate(), Replicate()), 2 sequential all_reduce operations will be performed. This is suboptimal: multiple collective operations have higher latency (separate kernel launches and synchronization points) and may give inconsistent results between ranks due to different reduction orders. To optimize, flatten mesh dimensions ["fsdp", "tp"] so DTensor can use a single operation instead.
Step    0 | Loss: 16.2702 | Grad norm: 5504.0000
Step    1 | Loss: 15.0783 | Grad norm: 210.0000
Step    2 | Loss: 14.4346 | Grad norm: 492.0000
Step    3 | Loss: 14.1189 | Grad norm: 54.0000
Step    4 | Loss: 12.4449 | Grad norm: 29.2500
Step    5 | Loss: 11.5218 | Grad norm: 62.7500
Step    6 | Loss: 11.0154 | Grad norm: 18.6250
Step    7 | Loss: 9.4406 | Grad norm: 7.5312
Step    8 | Loss: 9.1040 | Grad norm: 7.5938
Step    9 | Loss: 8.8089 | Grad norm: 23.1250
Step   10 | Loss: 8.3567 | Grad norm: 6.9375
Step   11 | Loss: 8.5408 | Grad norm: 40.7500
Step   12 | Loss: 8.3762 | Grad norm: 7.8750
Step   13 | Loss: 7.8817 | Grad norm: 4.9688
Step   14 | Loss: 7.9886 | Grad norm: 5.4688
Step   15 | Loss: 8.0018 | Grad norm: 4.3750
Step   16 | Loss: 8.1857 | Grad norm: 2.6250
Step   17 | Loss: 7.7257 | Grad norm: 3.5625
Step   18 | Loss: 8.3246 | Grad norm: 7.2500
Step   19 | Loss: 7.7896 | Grad norm: 5.8750
Step   20 | Loss: 8.4081 | Grad norm: 5.1875
Step   21 | Loss: 7.1414 | Grad norm: 4.3125
Step   22 | Loss: 7.7173 | Grad norm: 4.2188
Step   23 | Loss: 7.7260 | Grad norm: 3.1719
Step   24 | Loss: 7.4410 | Grad norm: 3.9844
Step   25 | Loss: 8.1295 | Grad norm: 2.7812
Step   26 | Loss: 7.6850 | Grad norm: 3.0938
Step   27 | Loss: 7.9302 | Grad norm: 4.3438
Step   28 | Loss: 7.9948 | Grad norm: 5.4062
Step   29 | Loss: 7.9131 | Grad norm: 5.1562
Step   30 | Loss: 7.9130 | Grad norm: 7.0625
Step   31 | Loss: 7.7614 | Grad norm: 3.7656
Step   32 | Loss: 7.6384 | Grad norm: 2.9062
Step   33 | Loss: 7.9990 | Grad norm: 2.2188
Step   34 | Loss: 7.9196 | Grad norm: 2.7188
Step   35 | Loss: 7.4915 | Grad norm: 2.3750
Step   36 | Loss: 7.6832 | Grad norm: 2.2500
Step   37 | Loss: 7.6955 | Grad norm: 5.6875
Step   38 | Loss: 7.8350 | Grad norm: 6.4375
Step   39 | Loss: 7.9663 | Grad norm: 7.0312
Step   40 | Loss: 7.5218 | Grad norm: 6.0000
Step   41 | Loss: 7.6298 | Grad norm: 3.6562
Step   42 | Loss: 7.3585 | Grad norm: 3.6562
Step   43 | Loss: 7.6517 | Grad norm: 2.5938
Step   44 | Loss: 7.4640 | Grad norm: 3.1562
Step   45 | Loss: 7.5150 | Grad norm: 2.8594
Step   46 | Loss: 7.7882 | Grad norm: 3.3750
Step   47 | Loss: 7.4193 | Grad norm: 2.6250
Step   48 | Loss: 7.5599 | Grad norm: 2.6719
Step   49 | Loss: 7.8967 | Grad norm: 2.7969
/fsx/ferdinandmom/ferdinand-hf/transformers_pr/env_refactor-tp-dtensor/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py:448: UserWarning: Detected an existing checkpoint in checkpoints/optimizer, overwriting since self.overwrite=True. Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to maintain this functionality or False to raise when an existing checkpoint is found.
  local_plan = storage_writer.prepare_local_plan(local_plan)
/fsx/ferdinandmom/ferdinand-hf/transformers_pr/env_refactor-tp-dtensor/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py:448: UserWarning: Detected an existing checkpoint in checkpoints/optimizer, overwriting since self.overwrite=True. Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to maintain this functionality or False to raise when an existing checkpoint is found.
  local_plan = storage_writer.prepare_local_plan(local_plan)
/fsx/ferdinandmom/ferdinand-hf/transformers_pr/env_refactor-tp-dtensor/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py:448: UserWarning: Detected an existing checkpoint in checkpoints/optimizer, overwriting since self.overwrite=True. Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to maintain this functionality or False to raise when an existing checkpoint is found.
  local_plan = storage_writer.prepare_local_plan(local_plan)
/fsx/ferdinandmom/ferdinand-hf/transformers_pr/env_refactor-tp-dtensor/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py:448: UserWarning: Detected an existing checkpoint in checkpoints/optimizer, overwriting since self.overwrite=True. Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to maintain this functionality or False to raise when an existing checkpoint is found.
  local_plan = storage_writer.prepare_local_plan(local_plan)
Saved to ./checkpoints

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@3outeille 3outeille force-pushed the refactor-tp-dtensor branch 2 times, most recently from fcea5ce to f98e208 Compare April 4, 2026 16:53
Comment thread src/transformers/models/qwen3/modeling_qwen3.py Outdated
- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests
@3outeille 3outeille force-pushed the fsdp-core-model-loading branch from 607cc11 to 739332c Compare April 13, 2026 14:14
- Replace hook-based TP with DTensor-based TPStyle API
- TPStyle dataclass with dense kinds: colwise, rowwise, vocab
- apply_tensor_parallel() using PyTorch parallelize_module
- verify_tp_plan() for plan validation
- Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle
- DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3
- Extended DistributedConfig with tp/fsdp size and plan fields
- DistributedConfig serialization in configuration_utils
- MXFP4 NotImplementedError for DTensor TP
- Dense TP tests
@3outeille 3outeille force-pushed the fsdp-core-model-loading branch from c567240 to c1dab9e Compare April 14, 2026 13:44
@3outeille 3outeille force-pushed the refactor-tp-dtensor branch from eb428cc to e0c4e06 Compare April 14, 2026 13:44
3outeille and others added 7 commits April 14, 2026 14:22
* MoE expert parallelism + sequence parallelism

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- _StridedShard handling in core_model_loading for interleaved weights
- MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral

* Fix ruff linting and formatting

* Fix ruff formatting in core_model_loading.py

* Restore _IdentityOp accidentally removed in 25a1f48

The _IdentityOp class (added by PR #44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Backport new TP/FSDP API + fix DTensor imports in Copied-from models

* from_pretrained orchestration + distributed save/load (#45409)

* from_pretrained orchestration + save/load

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates

* revert distributed utils

* eaaea

* all tests for core modeling are passing

* populate import from init for tp

* ruff

* ruff

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Comment thread src/transformers/integrations/__init__.py Outdated
return

# Filter out module-level comm hooks — they don't shard weights
_NON_WEIGHT_KINDS = {"activation", "module"}
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe separate tp and sp style ?

Comment thread src/transformers/integrations/tensor_parallel.py Outdated
Comment thread src/transformers/models/llama/modeling_llama.py Outdated
Comment thread src/transformers/models/mistral/modeling_mistral.py Outdated
Comment thread src/transformers/modeling_utils.py
Comment thread tests/utils/test_core_model_loading.py
Comment thread src/transformers/integrations/mxfp4.py
Comment thread src/transformers/integrations/mxfp4.py
Comment thread src/transformers/models/afmoe/modeling_afmoe.py Outdated
3outeille and others added 2 commits April 14, 2026 17:05
Restores modeling files to their base branch versions so the PR diff
only shows the distributed/patches.py monkey-patch approach instead of
noisy function moves in modeling files.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Convert string plan values ("colwise", "rowwise", etc.) to TPStyle
  objects across 66+ model configs and modular files
- Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...)
  with shard_plan
- Remove "replicated_with_grad_allreduce" entries (not needed for
  DTensor TP)
- Migrate _tp_plan class attributes in modeling files from
  "colwise_gather_output" to TPStyle("colwise", "allgather")
- Add TypeError in apply_tensor_parallel for unsupported plan values
- Remove old TensorParallelLayer tests (API removed in DTensor refactor)
- Regenerate auto-generated files via modular converter
Comment thread src/transformers/models/afmoe/modeling_afmoe.py Outdated
Comment thread src/transformers/models/apertus/configuration_apertus.py Outdated
return model

if isinstance(fsdp_plan, str):
if fsdp_plan == "auto":
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

define fsdp_plan in every model and remove auto code path (will be done in the next PR)

@3outeille 3outeille changed the base branch from fsdp-core-model-loading to fsdp-vs-ddp April 20, 2026 09:19
Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay reviewed only core model loading

model=model,
missing_keys=loading_info.missing_keys if loading_info else None,
)
if len(collected_tensors) > 1 and model is not None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:) this does not make sense to me

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at this point, only params that need quantization have self.quantization operation attached to them

return _job


class DtensorShardOperation:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be in TP or sharding utils not here


(b) One expert — source.ndim == param.ndim - 1
MoE models stack experts along a leading axis (E, ...) in the
model, but checkpoints store each expert in its own file
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not all checkpoints store them this way.

Comment on lines +892 to +896
(b) + (c) co-occurring
MoE checkpoint that is both per-expert and pre-pack (e.g.
`experts.2.w1.weight`). Resolve the expert axis first (b); the
generic loop then handles the remaining dims with (c) behavior.
"""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay this is a great start but :

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we want each case to be defined by a different string / by checking the operations that are gonna be applied to the layer, not by having a very general / fit all approach!

2) Multi-interval on one dim — read each piece, concat on that dim:
source shape = [8, 4]
intervals = [[(0, 2), (4, 6)], [(0, 4)]]
→ cat([source[0:2, 0:4], source[4:6, 0:4]], dim=0)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really don't think we need to cat....

out = torch.empty((4, 4), dtype=source.dtype, device=source.device)
out[0:2] = source[0:2, 0:4]
out[2:4] = source[4:6, 0:4]

worst case,

source[[0,1,4,5], 0:4]

does the same as well...

else:
expected_shape = ref.shape

# When a WeightConverter produces the full global tensor, slice it to the local DTensor shard.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmm I don't understand, how can that happen?

return renamed_key, source_pattern


def concretize_target_patterns(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remind me why we are touching this?

Comment on lines +1502 to +1510
empty_param = meta_model_state_dict[renamed_key]
try:
empty_param = model.get_parameter_or_buffer(renamed_key)
except (AttributeError, KeyError):
if getattr(model, "_is_fsdp_managed_module", False):
raise RuntimeError(
f"FSDP shard-on-read requires the live parameter for {renamed_key!r}, "
f"but get_parameter_or_buffer() failed."
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

anything fsdp specific should avoid taking space / being directly here


def _job():
return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype)
def _strided_intervals(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for strided, the most efficient is prob:

  1. have a staging buffer, copy slice to it
  2. copy twice (since the rows are in cache)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'll check this with @McPatate as well potentially but let's do what we can on our side!

pieces_read.append(source[tuple(piece_slices)])
return torch.cat(pieces_read, dim=multi_interval_dim).to(device=device, dtype=dtype)

def _owns_expert(self, expert_idx: int) -> bool:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does not make sense to have this here its specific

@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: afmoe, apertus, arcee, aria, audioflamingo3, bamba, bitnet

@github-actions
Copy link
Copy Markdown
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45028&sha=5b336b

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants