
🚨 Distributed training API #44989

Draft

3outeille wants to merge 8 commits into main from distributed_api

Conversation


@3outeille (Member) commented Mar 25, 2026

Distributed Training API

Goal

Train a model with combined FSDP + TP through a single from_pretrained call. The script below trains a small MoE model on 4 GPUs with a 2x2 (fsdp x tp) device mesh, streams and packs C4, and saves a consolidated checkpoint at the end:

# torchrun --nproc_per_node=4 train_fsdp_tp.py

import os
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.distributed import DistributedConfig
from transformers.distributed.utils import save_optimizer

def build_packed_dataset(dataset_name, tokenizer, seq_len, dp_rank, dp_world_size):
    """Stream + tokenize + greedy-pack documents into fixed-length (input, label) windows."""
    ds = load_dataset(dataset_name, name="en", split="train", streaming=True)
    ds = ds.shard(num_shards=dp_world_size, index=dp_rank)
    buf, w = [], seq_len + 1  # each window holds seq_len inputs plus one token for the shifted labels

    def pack(batch):
        for t in batch["text"]:
            buf.extend(tokenizer(t)["input_ids"])
        ids, lbls = [], []
        while len(buf) >= w:
            ids.append(buf[:seq_len])
            lbls.append(buf[1:w])
            del buf[:w]
        return {"input_ids": ids, "labels": lbls}

    ds = ds.map(pack, batched=True, remove_columns=ds.column_names)
    return ds.with_format("torch")

if __name__ == "__main__":

    model_name = "Isotonic/TinyMixtral-4x248M-MoE"
    num_steps, lr = 50, 3e-4
    save_dir = "./checkpoints"

    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        distributed_config=DistributedConfig(tp_size=2, fsdp_size=2),
        torch_dtype=torch.bfloat16,
    )

    rank = torch.distributed.get_rank()
    # Data-parallel coordinates, used only for dataset sharding (not device placement).
    dp_rank = model.device_mesh["fsdp"].get_local_rank()
    dp_world_size = model.device_mesh["fsdp"].size()
    
    dataset = build_packed_dataset("allenai/c4", tokenizer, 512, dp_rank=dp_rank, dp_world_size=dp_world_size)
    dataloader = DataLoader(dataset, batch_size=4)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    
    model.train()
    for step, batch in enumerate(dataloader):
        if step >= num_steps:
            break
        input_ids = batch["input_ids"].to(f"cuda:{local_rank}")
        # The model shifts labels internally, so pass the unshifted ids as labels;
        # packed windows contain no padding, but mask pad ids defensively anyway.
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100

        loss = model(input_ids, labels=labels).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if rank == 0 and step % 10 == 0:
            print(f"Step {step:>4d} | Loss: {loss.item():.4f}")

    model.save_pretrained(save_dir)
    save_optimizer(optimizer, os.path.join(save_dir, "optimizer"))
    if rank == 0:
        tokenizer.save_pretrained(save_dir)
        print(f"Saved to {save_dir}")

    torch.distributed.destroy_process_group()
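
After training, the saved directory can be sanity-checked in a single process. A minimal sketch (this mirrors what verify_loading.py and tmp_generate.py from the commit list below presumably do; it assumes save_pretrained wrote a consolidated, unsharded checkpoint):

# Sketch: reload the consolidated checkpoint on one GPU and generate.
# Assumes the distributed save above produced a standard HF checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./checkpoints")
model = AutoModelForCausalLM.from_pretrained(
    "./checkpoints", torch_dtype=torch.bfloat16
).to("cuda")
model.eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))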

PR Chain

| Review order | PR | Branch | Content |
|---|---|---|---|
| 1st | #45409 | orchestration-save-load → moe-sequence-parallel | from_pretrained orchestration, gather_full_state_dict(), save/load roundtrip |
| 2nd | #45408 | moe-sequence-parallel → refactor-tp-dtensor | PackedColwiseParallel, MoEExpertsParallel, sequence parallelism, MoE configs (mixtral, deepseek_v3, qwen3) |
| 3rd | #45028 | refactor-tp-dtensor → fsdp-core-model-loading | TPStyle API, apply_tensor_parallel(), dense model configs (llama, mistral, qwen2, phi, glm) |
| 4th | #44974 | fsdp-core-model-loading → fsdp-vs-ddp | DistributedConfig, DtensorShardOperation, shard-on-read loading |
| 5th | #44083 | fsdp-vs-ddp → distributed_api | FSDP2 fully_shard integration, auto/manual mode, FSDP vs DDP parity tests |
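
For orientation, here is a rough sketch of how the pieces in this chain compose at the call-site level. The import paths and the gather_full_state_dict() call site are assumptions inferred from the table above, not the final API:

# Hypothetical sketch; names and call sites inferred from the PR chain table.
import torch
from transformers import AutoModelForCausalLM
from transformers.distributed import DistributedConfig

torch.distributed.init_process_group(backend="nccl")

# Shard-on-read loading (4th PR): each rank materializes only its own shard,
# with TP styles (3rd PR) and MoE/sequence parallelism (2nd PR) applied per config.
model = AutoModelForCausalLM.from_pretrained(
    "Isotonic/TinyMixtral-4x248M-MoE",
    distributed_config=DistributedConfig(tp_size=2, fsdp_size=2),
    torch_dtype=torch.bfloat16,
)

# Save/load roundtrip (1st PR): gather sharded DTensors into a full state
# dict before writing (assumed call site for gather_full_state_dict()).
full_state_dict = model.gather_full_state_dict()
if torch.distributed.get_rank() == 0:
    torch.save(full_state_dict, "full_model.pt")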

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44989&sha=69bc48

3outeille and others added 4 commits April 13, 2026 16:34
- train_fsdp_tp.py: minimal FSDP+TP training example
- train_fsdp_tp_torchtitan_style.py: torchtitan-style training example
- verify_loading.py: save/load roundtrip verification
- run_compare.sh: FSDP+TP vs FSDP-only comparison (see the sketch after this list)
- run_verify_all.sh: run verification across all modes
- tmp_generate.py: quick generation test
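
For reference, run_compare.sh presumably launches the training script under two mesh layouts. A sketch of the two configurations (the FSDP-only values are assumed; only the 2x2 layout appears in this PR page):

# Sketch: the two layouts run_compare.sh presumably compares on 4 GPUs.
# The FSDP-only values are assumptions.
from transformers.distributed import DistributedConfig

fsdp_tp = DistributedConfig(tp_size=2, fsdp_size=2)    # 2x2 mesh: FSDP x TP
fsdp_only = DistributedConfig(tp_size=1, fsdp_size=4)  # flat mesh: FSDP only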
