From 7c843391b762cfe7c36882980992fcd557cac672 Mon Sep 17 00:00:00 2001 From: 3outeille Date: Wed, 25 Mar 2026 09:09:34 +0000 Subject: [PATCH 1/3] init --- tmp.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 tmp.py diff --git a/tmp.py b/tmp.py new file mode 100644 index 000000000000..05db1b840b44 --- /dev/null +++ b/tmp.py @@ -0,0 +1 @@ +"&&" From c33873e71c3225d233a0415ad10aa2313e032627 Mon Sep 17 00:00:00 2001 From: 3outeille Date: Mon, 13 Apr 2026 15:36:43 +0000 Subject: [PATCH 2/3] Add distributed training scripts - train_fsdp_tp.py: minimal FSDP+TP training example - train_fsdp_tp_torchtitan_style.py: torchtitan-style training example - verify_loading.py: save/load roundtrip verification - run_compare.sh: FSDP+TP vs FSDP-only comparison - run_verify_all.sh: run verification across all modes - tmp_generate.py: quick generation test --- run_compare.sh | 56 +++++++ run_verify_all.sh | 160 ++++++++++++++++++++ tmp_generate.py | 63 ++++++++ train_fsdp_tp.py | 125 ++++++++++++++++ train_fsdp_tp_torchtitan_style.py | 239 ++++++++++++++++++++++++++++++ verify_loading.py | 137 +++++++++++++++++ 6 files changed, 780 insertions(+) create mode 100644 run_compare.sh create mode 100644 run_verify_all.sh create mode 100644 tmp_generate.py create mode 100644 train_fsdp_tp.py create mode 100644 train_fsdp_tp_torchtitan_style.py create mode 100644 verify_loading.py diff --git a/run_compare.sh b/run_compare.sh new file mode 100644 index 000000000000..eb47e1841fa9 --- /dev/null +++ b/run_compare.sh @@ -0,0 +1,56 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT="train_fsdp_tp.py" +LOG_FSDP_TP="log.txt" +LOG_FSDP_ONLY="ref.txt" + +MODEL_NAME="${MODEL_NAME:-hf-internal-testing/tiny-random-MixtralForCausalLM}" +COMMON_ARGS="--model_name $MODEL_NAME --lr 3e-4 --seed 42" + +rm -rf ./checkpoints_tp ./checkpoints_tp_resumed ./checkpoints_fsdp ./checkpoints_fsdp_resumed + +echo "=== Phase 1: Train steps 0-9, save checkpoint ===" +echo "--- Launching FSDP+TP and FSDP-only in parallel 
---" + +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29500 \ + $SCRIPT $COMMON_ARGS --fsdp_size 2 --tp_size 2 --enable_sp \ + --num_steps 10 --save_dir ./checkpoints_tp > "${LOG_FSDP_TP}.phase1" 2>&1 & +PID1=$! + +CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=29501 \ + $SCRIPT $COMMON_ARGS --fsdp_size 2 \ + --num_steps 10 --save_dir ./checkpoints_fsdp > "${LOG_FSDP_ONLY}.phase1" 2>&1 & +PID2=$! + +echo "FSDP+TP PID=$PID1 | FSDP-only PID=$PID2" +wait $PID1 && echo "Phase 1 FSDP+TP done" || { echo "Phase 1 FSDP+TP failed (exit $?)"; cat "${LOG_FSDP_TP}.phase1"; exit 1; } +wait $PID2 && echo "Phase 1 FSDP-only done" || { echo "Phase 1 FSDP-only failed (exit $?)"; cat "${LOG_FSDP_ONLY}.phase1"; exit 1; } + +echo "" +echo "=== Phase 2: Resume from checkpoint, train steps 10-19, save ===" +echo "--- Launching FSDP+TP and FSDP-only in parallel ---" + +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29500 \ + $SCRIPT $COMMON_ARGS --fsdp_size 2 --tp_size 2 --enable_sp \ + --num_steps 10 --start_step 10 \ + --resume_dir ./checkpoints_tp --save_dir ./checkpoints_tp_resumed > "${LOG_FSDP_TP}.phase2" 2>&1 & +PID1=$! + +CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=29501 \ + $SCRIPT $COMMON_ARGS --fsdp_size 2 \ + --num_steps 10 --start_step 10 \ + --resume_dir ./checkpoints_fsdp --save_dir ./checkpoints_fsdp_resumed > "${LOG_FSDP_ONLY}.phase2" 2>&1 & +PID2=$! 
+ +echo "FSDP+TP PID=$PID1 | FSDP-only PID=$PID2" +wait $PID1 && echo "Phase 2 FSDP+TP done" || { echo "Phase 2 FSDP+TP failed (exit $?)"; cat "${LOG_FSDP_TP}.phase2"; exit 1; } +wait $PID2 && echo "Phase 2 FSDP-only done" || { echo "Phase 2 FSDP-only failed (exit $?)"; cat "${LOG_FSDP_ONLY}.phase2"; exit 1; } + +# Combine phase logs +cat "${LOG_FSDP_TP}.phase1" "${LOG_FSDP_TP}.phase2" > "$LOG_FSDP_TP" +cat "${LOG_FSDP_ONLY}.phase1" "${LOG_FSDP_ONLY}.phase2" > "$LOG_FSDP_ONLY" + +echo "" +echo "=== Full Loss & Grad Diff (steps 0-19) ===" +git diff --no-index --color --word-diff=color "$LOG_FSDP_TP" "$LOG_FSDP_ONLY" || true diff --git a/run_verify_all.sh b/run_verify_all.sh new file mode 100644 index 000000000000..16aa3267fe9a --- /dev/null +++ b/run_verify_all.sh @@ -0,0 +1,160 @@ +#!/bin/bash + +GREEN='\033[0;32m' +RED='\033[0;31m' +CYAN='\033[0;36m' +YELLOW='\033[1;33m' +BOLD='\033[1m' +DIM='\033[0;90m' +NC='\033[0m' + +SCRIPT="verify_loading.py" +LOGDIR="$(dirname "$0")/verify_logs" +mkdir -p "$LOGDIR" + +NUM_GPUS=$(nvidia-smi -L | wc -l) + +# Job definitions: "mode nproc_per_node" +declare -a JOBS=( + "single_gpu 1" + "fsdp 2" + "tp 2" + "tp_sp 2" + "tp_fsdp 4" + "tp_sp_fsdp 4" +) +MODE_NAMES=(single_gpu fsdp tp tp_sp tp_fsdp tp_sp_fsdp) + +echo -e "${BOLD}==========================================" +echo -e " Verify Loading (${NUM_GPUS} GPUs available)" +echo -e " Modes: ${MODE_NAMES[*]}" +echo -e " Logs: $LOGDIR/" +echo -e "==========================================${NC}" +echo "" + +# ============================================================ +# Round-robin GPU scheduler +# ============================================================ +NEXT_GPU=0 +MASTER_PORT=29500 +PIDS=() +PID_MODES=() + +for job in "${JOBS[@]}"; do + mode=${job% *} + nproc=${job#* } + + # Wait if not enough GPUs left in this round + if [ $((NEXT_GPU + nproc)) -gt "$NUM_GPUS" ]; then + echo -e "${DIM} (waiting for current round to finish...)${NC}" + for pid in "${PIDS[@]}"; do + wait 
"$pid" 2>/dev/null + done + PIDS=() + NEXT_GPU=0 + fi + + # Build CUDA_VISIBLE_DEVICES range + GPU_END=$((NEXT_GPU + nproc - 1)) + GPUS="" + for g in $(seq "$NEXT_GPU" "$GPU_END"); do + [ -n "$GPUS" ] && GPUS="${GPUS}," + GPUS="${GPUS}${g}" + done + + echo -e " ${CYAN}[${mode}]${NC} GPUs ${NEXT_GPU}-${GPU_END} (nproc=${nproc})" + + if [ "$nproc" -eq 1 ]; then + CUDA_VISIBLE_DEVICES="$GPUS" python "$SCRIPT" --mode "$mode" \ + > "$LOGDIR/${mode}.log" 2>&1 & + else + CUDA_VISIBLE_DEVICES="$GPUS" torchrun \ + --nproc_per_node="$nproc" --master_port="$MASTER_PORT" \ + "$SCRIPT" --mode "$mode" \ + > "$LOGDIR/${mode}.log" 2>&1 & + ((MASTER_PORT++)) + fi + + PIDS+=($!) + PID_MODES+=("$mode") + NEXT_GPU=$((GPU_END + 1)) +done + +# Wait for remaining jobs +echo "" +echo -e "${BOLD}Waiting for all jobs to finish...${NC}" +for i in "${!PIDS[@]}"; do + mode="${PID_MODES[$i]}" + if wait "${PIDS[$i]}"; then + echo -e " ${GREEN}✓${NC} ${mode}" + else + echo -e " ${RED}✗${NC} ${mode} (exit $?)" + fi +done + +# ============================================================ +# Results +# ============================================================ +echo "" +echo -e "${BOLD}=== Results ===${NC}" +for mode in "${MODE_NAMES[@]}"; do + log="$LOGDIR/$mode.log" + loss_before=$(grep -oP 'loss_before = \K[0-9.]+' "$log" 2>/dev/null) + loss_after=$(grep -oP 'loss_after = \K[0-9.]+' "$log" 2>/dev/null) + if grep -q '^PASS' "$log" 2>/dev/null; then + printf " ${GREEN}%-12s PASS (before=%-10s after=%s)${NC}\n" "$mode" "$loss_before" "$loss_after" + elif [ -n "$loss_before" ]; then + diff=$(grep -oP 'diff = \K[0-9.e+-]+' "$log" 2>/dev/null) + printf " ${RED}%-12s FAIL (before=%-10s after=%-10s diff=%s)${NC}\n" "$mode" "$loss_before" "$loss_after" "$diff" + else + printf " ${RED}%-12s ERROR (see log)${NC}\n" "$mode" + fi +done + +# ============================================================ +# Cross-mode loss comparison +# ============================================================ +echo "" +echo 
-e "${BOLD}=== Cross-mode loss comparison (PASS modes only) ===${NC}" +REF_LOSS="" +ALL_MATCH=1 +for mode in "${MODE_NAMES[@]}"; do + log="$LOGDIR/$mode.log" + # Only include modes where save/load roundtrip passed + if ! grep -q '^PASS' "$log" 2>/dev/null; then + continue + fi + loss=$(grep -oP 'loss_before = \K[0-9.]+' "$log" 2>/dev/null) + if [ -z "$loss" ]; then + continue + fi + if [ -z "$REF_LOSS" ]; then + REF_LOSS="$loss" + printf " ${GREEN}%-12s %s (reference)${NC}\n" "$mode" "$loss" + elif [ "$loss" = "$REF_LOSS" ]; then + printf " ${GREEN}%-12s %s${NC}\n" "$mode" "$loss" + else + printf " ${YELLOW}%-12s %s (differs from %s)${NC}\n" "$mode" "$loss" "$REF_LOSS" + ALL_MATCH=0 + fi +done +if [ "$ALL_MATCH" -eq 1 ] && [ -n "$REF_LOSS" ]; then + echo -e " ${GREEN}All modes produce the same loss.${NC}" +fi + +# Hints for failures +HAS_FAIL=0 +for mode in "${MODE_NAMES[@]}"; do + if ! grep -q '^PASS' "$LOGDIR/$mode.log" 2>/dev/null; then + HAS_FAIL=1 + fi +done +if [ "$HAS_FAIL" -eq 1 ]; then + echo "" + echo -e "${YELLOW}Some modes failed. Check logs:${NC}" + for mode in "${MODE_NAMES[@]}"; do + if ! 
grep -q '^PASS' "$LOGDIR/$mode.log" 2>/dev/null; then + echo -e " ${YELLOW}cat $LOGDIR/$mode.log${NC}" + fi + done +fi diff --git a/tmp_generate.py b/tmp_generate.py new file mode 100644 index 000000000000..9685bed643ed --- /dev/null +++ b/tmp_generate.py @@ -0,0 +1,63 @@ +import argparse +import os + +import torch +from torch.distributed.elastic.multiprocessing.errors import record + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.distributed import DistributedConfig + +model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1" +# model_id = "Qwen/Qwen3-14B" +# model_id = "Qwen/Qwen3-0.6B" +# model_id = "Qwen/Qwen1.5-MoE-A2.7B-Chat" +# model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507" + +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +device = torch.device(f"cuda:{rank}") +# Need to be initialized explicitly to use the `barrier` before loading +torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, device_id=rank) + +@record +def main(args): + + distributed_config = DistributedConfig(tp_size=4, tp_plan="auto") + model = AutoModelForCausalLM.from_pretrained(model_id, distributed_config=distributed_config, dtype=torch.bfloat16) + # model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + messages = [ + {"role": "user", "content": "What do you think about life?"}, + ] + inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device) + input_size = inputs.input_ids.shape[-1] + + if args.profile: + # Warmup + with torch.no_grad(): + _ = model.generate(**inputs, max_new_tokens=5, do_sample=False) + + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + record_shapes=True, + ) as prof: + output = model.generate(**inputs, max_new_tokens=2, do_sample=False) + + if rank == 0: + 
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30)) + prof.export_chrome_trace("trace.json") + else: + output = model.generate(**inputs, max_new_tokens=100, do_sample=False) + + text = tokenizer.batch_decode(output[:, input_size:])[0] + if rank == 0: + print(text) + +parser = argparse.ArgumentParser() +parser.add_argument("--profile", action="store_true") +args = parser.parse_args() + +main(args) + +torch.distributed.destroy_process_group() \ No newline at end of file diff --git a/train_fsdp_tp.py b/train_fsdp_tp.py new file mode 100644 index 000000000000..0232f8b3bc3d --- /dev/null +++ b/train_fsdp_tp.py @@ -0,0 +1,125 @@ +# torchrun --nproc_per_node=4 train_fsdp_tp.py + +import argparse +import os + +import torch +from datasets import load_dataset +from torch.distributed.tensor import DTensor +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.distributed import DistributedConfig +from transformers.distributed.utils import load_optimizer, save_optimizer + +def build_packed_dataset(dataset_name, tokenizer, seq_len, dp_rank, dp_world_size): + """Stream + tokenize + greedy-pack documents into fixed-length (input, label) windows.""" + ds = load_dataset(dataset_name, name="en", split="train", streaming=True) + ds = ds.shard(num_shards=dp_world_size, index=dp_rank) + buf, w = [], seq_len + 1 + + def pack(batch): + for t in batch["text"]: + buf.extend(tokenizer(t)["input_ids"]) + ids, lbls = [], [] + while len(buf) >= w: + ids.append(buf[:seq_len]); lbls.append(buf[1:w]); del buf[:w] + return {"input_ids": ids, "labels": lbls} + + ds = ds.map(pack, batched=True, remove_columns=ds.column_names) + return ds.with_format("torch") + +def build_fixed_batches(dp_rank): + """Load pre-generated fixed batches for a given DP rank.""" + return torch.load(f"fixed_batches_dp{dp_rank}.pt", weights_only=True) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + 
parser.add_argument("--model_name", type=str, default="Qwen/Qwen3-0.6B") + parser.add_argument("--num_steps", type=int, default=20) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--seq_len", type=int, default=512) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--save_dir", type=str, default="./checkpoints") + parser.add_argument("--tp_size", type=int, default=0, help="Tensor parallel size (0 = disabled)") + parser.add_argument("--fsdp_size", type=int, default=0, help="FSDP size (0 = disabled)") + parser.add_argument("--enable_sp", action="store_true", help="Enable sequence parallelism") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--fixed_batches", action="store_true", help="Use pre-generated fixed batches instead of C4") + parser.add_argument("--resume_dir", type=str, default=None, help="Resume from this checkpoint directory") + parser.add_argument("--start_step", type=int, default=0, help="Starting step number (for logging)") + args = parser.parse_args() + + torch.distributed.init_process_group(backend="nccl") + rank, local_rank = int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + torch.manual_seed(args.seed) + + dc_kwargs = {} + if args.tp_size > 0: + dc_kwargs["tp_size"] = args.tp_size + dc_kwargs["tp_plan"] = "auto" + if args.fsdp_size > 0: + dc_kwargs["fsdp_size"] = args.fsdp_size + dc_kwargs["fsdp_plan"] = "auto" + if args.enable_sp: + dc_kwargs["enable_sequence_parallel"] = True + distributed_config = DistributedConfig(**dc_kwargs) + + load_path = args.resume_dir if args.resume_dir else args.model_name + model = AutoModelForCausalLM.from_pretrained( + load_path, + distributed_config=distributed_config, + torch_dtype=torch.bfloat16, + ) + + dp_rank = model.device_mesh["fsdp"].get_local_rank() if "fsdp" in model.device_mesh.mesh_dim_names else 0 + dp_size = model.device_mesh["fsdp"].size() if "fsdp" 
in model.device_mesh.mesh_dim_names else 1 + + if args.fixed_batches: + fixed = build_fixed_batches(dp_rank) + else: + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + dataset = build_packed_dataset("allenai/c4", tokenizer, args.seq_len, dp_rank, dp_size) + dataloader = iter(DataLoader(dataset, batch_size=args.batch_size)) + + optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) + + if args.resume_dir: + load_optimizer(optimizer, os.path.join(args.resume_dir, "optimizer")) + if rank == 0: + print(f"Resumed optimizer from {args.resume_dir}") + + model.train() + for step in range(args.start_step, args.start_step + args.num_steps): + if args.fixed_batches: + input_ids = fixed[step]["input_ids"].to(f"cuda:{local_rank}") + labels = fixed[step]["labels"].to(f"cuda:{local_rank}") + else: + batch = next(dataloader) + input_ids = batch["input_ids"].to(f"cuda:{local_rank}") + labels = batch["labels"].to(f"cuda:{local_rank}") + loss = model(input_ids, labels=labels).loss + loss.backward() + + # Custom grad clip: convert DTensor grads to local to avoid mixed-mesh torch.stack + grads = [p.grad for p in model.parameters() if p.grad is not None] + local_grads = [g.full_tensor() if isinstance(g, DTensor) else g for g in grads] + total_norm = torch.nn.utils.get_total_norm(local_grads, norm_type=2.0) + torch.nn.utils.clip_grads_with_norm_(grads, max_norm=1.0, total_norm=total_norm) + optimizer.step() + optimizer.zero_grad() + + if rank == 0: + print(f"Step {step:>4d} | Loss: {loss.item():.4f} | Grad norm: {total_norm.item():.4f}") + + # Save model (HF format) and optimizer (DCP) + model.save_pretrained(args.save_dir) + save_optimizer(optimizer, os.path.join(args.save_dir, "optimizer")) + + if rank == 0: + print(f"Saved to {args.save_dir}") + + torch.distributed.destroy_process_group() diff --git a/train_fsdp_tp_torchtitan_style.py b/train_fsdp_tp_torchtitan_style.py new file mode 
100644 index 000000000000..325ee112c778 --- /dev/null +++ b/train_fsdp_tp_torchtitan_style.py @@ -0,0 +1,239 @@ +# torchrun --nproc_per_node=4 train_fsdp_tp_torchtitan_style.py +# LOAD_PRETRAINED=1 torchrun --nproc_per_node=4 train_fsdp_tp_torchtitan_style.py +# +# Minimal standalone training script that reuses torchtitan's components +# (model wrapper, parallelization, loss, optimizer, grad clipping) directly. +# This is the same code path as `./run_train.sh` but without the config system. + +import os + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import torch.nn.functional as F +from huggingface_hub import snapshot_download +from torch.distributed.checkpoint import HuggingFaceStorageReader + +# ---------- torchtitan imports ---------- +from torchtitan.distributed import ParallelDims +from torchtitan.distributed import utils as dist_utils +from torchtitan.experiments.transformers_modeling_backend.infra.parallelize import ( + apply_fsdp, + apply_non_moe_tp, + disable_fsdp_gradient_division, +) +from torchtitan.experiments.transformers_modeling_backend.model.args import ( + HFTransformerModelArgs, + TitanDenseModelArgs, +) +from torchtitan.experiments.transformers_modeling_backend.model.model import ( + HFTransformerModel, +) + +# ---------- transformers imports ---------- +from transformers import AutoConfig, AutoTokenizer + +IGNORE_INDEX = -100 + + +def build_model_args(hf_model_name: str, seq_len: int) -> HFTransformerModelArgs: + """Build HFTransformerModelArgs from a HuggingFace model name.""" + hf_config = AutoConfig.from_pretrained( + hf_model_name, attn_implementation="sdpa", trust_remote_code=True + ) + hf_config_dict = hf_config.to_dict() + + model_args = HFTransformerModelArgs(titan_dense_args=TitanDenseModelArgs()) + + # Map TorchTitan attr names → HF attr names + for titan_name, hf_name in model_args._tt_to_hf_attribute_map.items(): + if hasattr(hf_config, hf_name): + setattr(model_args, titan_name, 
getattr(hf_config, hf_name)) + + # Copy all HF config attributes + for key, value in hf_config_dict.items(): + setattr(model_args, key, value) + + # Override with training-specific settings + model_args.max_seq_len = seq_len + model_args.deterministic = False + model_args.attention_bias = False + model_args.mlp_bias = False + model_args.use_cache = False + model_args.initializer_range = 1.0 + model_args.pruned_heads = getattr(hf_config, "pruned_heads", {}) + + if "head_dim" not in hf_config_dict: + model_args.head_dim = model_args.dim // model_args.num_attention_heads + + return model_args + + +if __name__ == "__main__": + # ── Config ────────────────────────────────────────────────────────── + model_name = "Qwen/Qwen3-0.6B" + seq_len = 512 + num_steps = 50 + lr = 3e-4 + max_norm = 1.0 + tp_degree = 2 + dp_degree = 2 # FSDP shard degree + batch_size = 4 + + # ── Distributed init ──────────────────────────────────────────────── + dist.init_process_group(backend="nccl") + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + + parallel_dims = ParallelDims( + dp_shard=dp_degree, + dp_replicate=1, + tp=tp_degree, + pp=1, + ep=1, + etp=1, + cp=1, + world_size=world_size, + ) + world_mesh = parallel_dims.build_mesh() + + # ── C4 dataset (same as torchtitan) ───────────────────────────────── + from torchtitan.hf_datasets.text_datasets import build_text_dataloader + from torchtitan.components.tokenizer import build_hf_tokenizer + from torchtitan.config.job_config import JobConfig as TTJobConfig + from types import SimpleNamespace + + tt_tokenizer = build_hf_tokenizer( + SimpleNamespace( + model=SimpleNamespace( + hf_assets_path=snapshot_download(model_name), + name="transformers_modeling_backend", + tokenizer_path="", + ) + ) + ) + dp_rank = parallel_dims.get_mesh("fsdp").get_local_rank() + dp_world_size = 
parallel_dims.get_mesh("fsdp").size() + tt_job_config = TTJobConfig() + tt_job_config.training.dataset = "c4" + tt_job_config.training.dataset_path = None + tt_job_config.training.local_batch_size = batch_size + tt_job_config.training.seq_len = seq_len + dataloader = build_text_dataloader( + dp_world_size=dp_world_size, + dp_rank=dp_rank, + tokenizer=tt_tokenizer, + job_config=tt_job_config, + infinite=True, + ) + + # ── Model ─────────────────────────────────────────────────────────── + model_args = build_model_args(model_name, seq_len) + + with torch.device("meta"): + model = HFTransformerModel(model_args) + + # ── Parallelize (same as torchtitan's parallelize_hf_transformers) ── + tp_mesh = parallel_dims.get_mesh("tp") + apply_non_moe_tp( + model, + tp_mesh, + loss_parallel=True, # lm_head output → Shard(-1) + enable_float8_tensorwise_tp=False, + ) + + dp_mesh = parallel_dims.get_mesh("fsdp") + apply_fsdp( + model, + dp_mesh, + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + pp_enabled=False, + ) + disable_fsdp_gradient_division(model) + + # ── Materialize + init weights ────────────────────────────────────── + model.to_empty(device=device) + with torch.no_grad(): + model.init_weights() + model.train() + + # ── (Optional) Load pretrained weights via DCP ────────────────────── + # Set LOAD_PRETRAINED=1 to load HF weights. Default: train from random init + # (matching what torchtitan's run_train.sh does without a checkpoint). + if os.environ.get("LOAD_PRETRAINED", "0") == "1": + checkpoint_path = snapshot_download(model_name) + state_dict = model.state_dict() + PREFIX = "model." 
+ hf_keyed = {k[len(PREFIX):]: v for k, v in state_dict.items() if k.startswith(PREFIX)} + dcp.load(hf_keyed, storage_reader=HuggingFaceStorageReader(checkpoint_path)) + model.load_state_dict({PREFIX + k: v for k, v in hf_keyed.items()}) + if rank == 0: + print("Pretrained weights loaded via DCP.") + else: + if rank == 0: + print("Training from random init (no pretrained weights).") + + # ── Optimizer ─────────────────────────────────────────────────────── + optimizer = torch.optim.AdamW( + model.parameters(), + lr=lr, + betas=(0.9, 0.95), + eps=1e-8, + weight_decay=0.1, + fused=True, + ) + + # ── loss_parallel context (logits are Shard(-1) on TP mesh) ───────── + loss_parallel_enabled = parallel_dims.tp_enabled + train_context = dist_utils.get_train_context(loss_parallel_enabled) + + # ── Training loop ─────────────────────────────────────────────────── + data_iterator = iter(dataloader) + for step in range(num_steps): + optimizer.zero_grad() + + # torchtitan dataloader yields ({"input": input_ids}, labels) + # both of shape (batch, seq_len) — already shifted, no padding. + input_dict, labels = next(data_iterator) + input_ids = input_dict["input"].to(device) + labels = labels.to(device) + + # No padding in C4 stream — all tokens are valid + local_valid_tokens = (labels != IGNORE_INDEX).sum().to(device) + global_valid_tokens = dist_utils.dist_sum( + local_valid_tokens, parallel_dims.get_mesh("batch") + ) + + # Forward + loss under train_context (enables loss_parallel if TP) + # input_ids and labels are same length (seq_len), already shifted by dataloader. + # pred aligns directly with labels — no slicing needed. 
+ with train_context(): + pred = model(input_ids) # (batch, seq_len, vocab) as Shard(-1) DTensor + loss_sum = F.cross_entropy( + pred.flatten(0, 1).float(), + labels.flatten(0, 1), + reduction="sum", + ignore_index=IGNORE_INDEX, + ) + loss = loss_sum / global_valid_tokens + del pred + loss.backward() + + # Gradient clipping (torchtitan's implementation) + grad_norm = dist_utils.clip_grad_norm_( + list(model.parameters()), max_norm, foreach=True + ) + + optimizer.step() + + if rank == 0: + print( + f"Step {step:>4d} | Loss: {loss.item():.4f} | " + f"Grad norm: {grad_norm.item():.4f}" + ) + + dist.destroy_process_group() diff --git a/verify_loading.py b/verify_loading.py new file mode 100644 index 000000000000..ea008f9626f7 --- /dev/null +++ b/verify_loading.py @@ -0,0 +1,137 @@ +# Save/load roundtrip test for distributed models (TP, FSDP, TP+FSDP). +# +# Verifies that save_pretrained → from_pretrained preserves model weights by +# checking that the cross-entropy loss is identical before and after the roundtrip. +# This catches bugs in DTensor gather-on-save and shard-on-read paths. 
+# +# Usage: +# python verify_loading.py --mode single_gpu +# torchrun --nproc_per_node=2 verify_loading.py --mode fsdp +# torchrun --nproc_per_node=2 verify_loading.py --mode tp +# torchrun --nproc_per_node=4 verify_loading.py --mode tp_fsdp +# MODEL=Qwen/Qwen3-0.6B torchrun --nproc_per_node=2 verify_loading.py --mode tp +import argparse +import os +import shutil + +import torch +from torch.distributed.tensor import DTensor, Replicate + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.distributed import DistributedConfig + + +parser = argparse.ArgumentParser() +parser.add_argument("--mode", choices=["single_gpu", "fsdp", "tp", "tp_sp", "tp_fsdp", "tp_sp_fsdp"], required=True) +parser.add_argument("--model", type=str, default=None, help="Model ID (or set MODEL env var)") +args = parser.parse_args() + +model_id = args.model or os.environ.get("MODEL") or os.environ.get("MODEL_ID") or "hf-internal-testing/tiny-random-MixtralForCausalLM" + +if args.mode != "single_gpu": + torch.distributed.init_process_group(backend="nccl") + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) +else: + rank = 0 + local_rank = 0 + torch.cuda.set_device(0) + +configs = { + "single_gpu": None, + "fsdp": DistributedConfig(fsdp_size=2, fsdp_plan="auto"), + "tp": DistributedConfig(tp_size=2, tp_plan="auto"), + "tp_sp": DistributedConfig(tp_size=2, tp_plan="auto", enable_sequence_parallel=True), + "tp_fsdp": DistributedConfig(tp_size=2, tp_plan="auto", fsdp_size=2, fsdp_plan="auto"), + "tp_sp_fsdp": DistributedConfig(tp_size=2, tp_plan="auto", fsdp_size=2, fsdp_plan="auto", enable_sequence_parallel=True), +} + +tokenizer = AutoTokenizer.from_pretrained(model_id) +text = "The capital of France is Paris. The largest ocean is the Pacific." 
+ + +def materialize_full_logits(logits: torch.Tensor) -> torch.Tensor: + if isinstance(logits, DTensor): + with torch.no_grad(): + return logits.redistribute(placements=[Replicate()] * logits.device_mesh.ndim, async_op=False).to_local() + return logits + + +def compute_loss(model): + inputs = tokenizer(text, return_tensors="pt").to(f"cuda:{local_rank}") + input_ids = inputs["input_ids"] + # Pad sequence length to a multiple of tp_size so DTensor Shard(1) splits evenly + # across ranks in SP mode. Always pad (even for non-TP modes) so that all modes + # compute on the same input and losses are directly comparable. + max_tp = max((c.tp_size if c is not None else 1) for c in configs.values()) + seq_len = input_ids.shape[1] + if seq_len % max_tp != 0: + pad_len = max_tp - (seq_len % max_tp) + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + input_ids = torch.cat([input_ids, input_ids.new_full((1, pad_len), pad_token_id)], dim=1) + labels = input_ids.clone() + labels[:, seq_len:] = -100 # ignore padding in loss + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0) + + model.eval() + with torch.no_grad(): + logits = model(input_ids, position_ids=position_ids).logits + logits = materialize_full_logits(logits) + loss = torch.nn.functional.cross_entropy( + logits.flatten(0, 1).float(), + labels.flatten(0, 1), + reduction="mean", + ignore_index=-100, + ) + return loss.item() + + +# --- Step 1: Load original model and compute loss --- +model = AutoModelForCausalLM.from_pretrained(model_id, distributed_config=configs[args.mode], dtype=torch.float32) +if args.mode == "single_gpu": + model = model.to("cuda:0") + +loss_before = compute_loss(model) +if rank == 0: + print(f"{args.mode}: loss_before = {loss_before:.6f}") + +# --- Step 2: Save to local dir (shared path across ranks) --- +save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"verify_ckpt_{args.mode}") +if rank == 
0: + if os.path.exists(save_dir): + shutil.rmtree(save_dir) + os.makedirs(save_dir) +if args.mode != "single_gpu": + torch.distributed.barrier() +model.save_pretrained(save_dir, is_main_process=(rank == 0)) +if rank == 0: + print(f"{args.mode}: saved to {save_dir}") + +# Ensure all ranks see the saved files before reloading +if args.mode != "single_gpu": + torch.distributed.barrier() + +del model +torch.cuda.empty_cache() + +# --- Step 3: Reload from saved checkpoint and compute loss --- +model2 = AutoModelForCausalLM.from_pretrained(save_dir, distributed_config=configs[args.mode], dtype=torch.float32) +if args.mode == "single_gpu": + model2 = model2.to("cuda:0") + +loss_after = compute_loss(model2) +if rank == 0: + print(f"{args.mode}: loss_after = {loss_after:.6f}") + +# --- Step 4: Compare --- +if rank == 0: + diff = abs(loss_before - loss_after) + print(f"{args.mode}: diff = {diff:.2e}") + if diff < 1e-5: + print("PASS: save/load roundtrip is lossless") + else: + print("FAIL: loss mismatch after save/load roundtrip!") + +if args.mode != "single_gpu": + torch.distributed.destroy_process_group() From 34db8405496c10371997e4f0de1cee93feb54ca9 Mon Sep 17 00:00:00 2001 From: 3outeille Date: Mon, 13 Apr 2026 15:40:05 +0000 Subject: [PATCH 3/3] Remove train_fsdp_tp_torchtitan_style.py --- train_fsdp_tp_torchtitan_style.py | 239 ------------------------------ 1 file changed, 239 deletions(-) delete mode 100644 train_fsdp_tp_torchtitan_style.py diff --git a/train_fsdp_tp_torchtitan_style.py b/train_fsdp_tp_torchtitan_style.py deleted file mode 100644 index 325ee112c778..000000000000 --- a/train_fsdp_tp_torchtitan_style.py +++ /dev/null @@ -1,239 +0,0 @@ -# torchrun --nproc_per_node=4 train_fsdp_tp_torchtitan_style.py -# LOAD_PRETRAINED=1 torchrun --nproc_per_node=4 train_fsdp_tp_torchtitan_style.py -# -# Minimal standalone training script that reuses torchtitan's components -# (model wrapper, parallelization, loss, optimizer, grad clipping) directly. 
# This is the same code path as `./run_train.sh` but without the config system.

import os
from types import SimpleNamespace

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from torch.distributed.checkpoint import HuggingFaceStorageReader

# ---------- torchtitan imports ----------
# NOTE(review): dataset/tokenizer/job-config imports were previously done
# lazily inside __main__; hoisted here per top-of-file import convention.
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.config.job_config import JobConfig as TTJobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed import utils as dist_utils
from torchtitan.experiments.transformers_modeling_backend.infra.parallelize import (
    apply_fsdp,
    apply_non_moe_tp,
    disable_fsdp_gradient_division,
)
from torchtitan.experiments.transformers_modeling_backend.model.args import (
    HFTransformerModelArgs,
    TitanDenseModelArgs,
)
from torchtitan.experiments.transformers_modeling_backend.model.model import (
    HFTransformerModel,
)
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

# ---------- transformers imports ----------
from transformers import AutoConfig, AutoTokenizer

# Label value skipped by the loss (standard HF/PyTorch ignore index).
IGNORE_INDEX = -100


def build_model_args(hf_model_name: str, seq_len: int) -> HFTransformerModelArgs:
    """Build HFTransformerModelArgs from a HuggingFace model name.

    Fetches the HF config for ``hf_model_name``, copies it onto a fresh
    ``HFTransformerModelArgs`` — first through the TorchTitan→HF attribute
    map, then every raw HF config key — and finally applies
    training-specific overrides.

    Args:
        hf_model_name: HuggingFace Hub model id (e.g. ``"Qwen/Qwen3-0.6B"``).
        seq_len: Maximum sequence length used for training.

    Returns:
        A fully populated ``HFTransformerModelArgs`` instance.
    """
    hf_config = AutoConfig.from_pretrained(
        hf_model_name, attn_implementation="sdpa", trust_remote_code=True
    )
    hf_config_dict = hf_config.to_dict()

    model_args = HFTransformerModelArgs(titan_dense_args=TitanDenseModelArgs())

    # Map TorchTitan attr names → HF attr names.
    for titan_name, hf_name in model_args._tt_to_hf_attribute_map.items():
        if hasattr(hf_config, hf_name):
            setattr(model_args, titan_name, getattr(hf_config, hf_name))

    # Copy all HF config attributes (raw HF key names; coexist with the
    # titan-named attributes set above).
    for key, value in hf_config_dict.items():
        setattr(model_args, key, value)

    # Override with training-specific settings.
    model_args.max_seq_len = seq_len
    model_args.deterministic = False
    model_args.attention_bias = False
    model_args.mlp_bias = False
    model_args.use_cache = False
    # NOTE(review): 1.0 differs from the usual HF default (0.02) — presumably
    # intentional so init matches torchtitan's init_weights scaling; confirm.
    model_args.initializer_range = 1.0
    model_args.pruned_heads = getattr(hf_config, "pruned_heads", {})

    # Some configs omit head_dim; derive it from hidden dim / attention heads.
    if "head_dim" not in hf_config_dict:
        model_args.head_dim = model_args.dim // model_args.num_attention_heads

    return model_args


if __name__ == "__main__":
    # ── Config ────────────────────────────────────────────────────────────
    model_name = "Qwen/Qwen3-0.6B"
    seq_len = 512
    num_steps = 50
    lr = 3e-4
    max_norm = 1.0
    tp_degree = 2
    dp_degree = 2  # FSDP shard degree
    batch_size = 4

    # ── Distributed init ──────────────────────────────────────────────────
    dist.init_process_group(backend="nccl")
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    parallel_dims = ParallelDims(
        dp_shard=dp_degree,
        dp_replicate=1,
        tp=tp_degree,
        pp=1,
        ep=1,
        etp=1,
        cp=1,
        world_size=world_size,
    )
    world_mesh = parallel_dims.build_mesh()

    # ── C4 dataset (same as torchtitan) ───────────────────────────────────
    # build_hf_tokenizer only needs these three fields; SimpleNamespace
    # stands in for the full job config object.
    tt_tokenizer = build_hf_tokenizer(
        SimpleNamespace(
            model=SimpleNamespace(
                hf_assets_path=snapshot_download(model_name),
                name="transformers_modeling_backend",
                tokenizer_path="",
            )
        )
    )
    # The dataloader shards the stream across the FSDP (data-parallel) mesh.
    dp_rank = parallel_dims.get_mesh("fsdp").get_local_rank()
    dp_world_size = parallel_dims.get_mesh("fsdp").size()
    tt_job_config = TTJobConfig()
    tt_job_config.training.dataset = "c4"
    tt_job_config.training.dataset_path = None
    tt_job_config.training.local_batch_size = batch_size
    tt_job_config.training.seq_len = seq_len
    dataloader = build_text_dataloader(
        dp_world_size=dp_world_size,
        dp_rank=dp_rank,
        tokenizer=tt_tokenizer,
        job_config=tt_job_config,
        infinite=True,
    )

    # ── Model ─────────────────────────────────────────────────────────────
    model_args = build_model_args(model_name, seq_len)

    # Construct on the meta device: no memory is allocated until to_empty().
    with torch.device("meta"):
        model = HFTransformerModel(model_args)

    # ── Parallelize (same as torchtitan's parallelize_hf_transformers) ────
    # TP must be applied before FSDP so FSDP shards the TP-sharded params.
    tp_mesh = parallel_dims.get_mesh("tp")
    apply_non_moe_tp(
        model,
        tp_mesh,
        loss_parallel=True,  # lm_head output → Shard(-1)
        enable_float8_tensorwise_tp=False,
    )

    dp_mesh = parallel_dims.get_mesh("fsdp")
    apply_fsdp(
        model,
        dp_mesh,
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        pp_enabled=False,
    )
    disable_fsdp_gradient_division(model)

    # ── Materialize + init weights ────────────────────────────────────────
    model.to_empty(device=device)
    with torch.no_grad():
        model.init_weights()
    model.train()

    # ── (Optional) Load pretrained weights via DCP ────────────────────────
    # Set LOAD_PRETRAINED=1 to load HF weights. Default: train from random
    # init (matching what torchtitan's run_train.sh does without a checkpoint).
    if os.environ.get("LOAD_PRETRAINED", "0") == "1":
        checkpoint_path = snapshot_download(model_name)
        state_dict = model.state_dict()
        PREFIX = "model."
        # Hub checkpoints are keyed without the wrapper's "model." prefix:
        # strip it for dcp.load, then re-add it before load_state_dict.
        hf_keyed = {k[len(PREFIX):]: v for k, v in state_dict.items() if k.startswith(PREFIX)}
        dcp.load(hf_keyed, storage_reader=HuggingFaceStorageReader(checkpoint_path))
        model.load_state_dict({PREFIX + k: v for k, v in hf_keyed.items()})
        if rank == 0:
            print("Pretrained weights loaded via DCP.")
    else:
        if rank == 0:
            print("Training from random init (no pretrained weights).")

    # ── Optimizer ─────────────────────────────────────────────────────────
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.1,
        fused=True,
    )

    # ── loss_parallel context (logits are Shard(-1) on TP mesh) ───────────
    loss_parallel_enabled = parallel_dims.tp_enabled
    train_context = dist_utils.get_train_context(loss_parallel_enabled)

    # Loop-invariant: mesh used to all-reduce the valid-token count each step.
    batch_mesh = parallel_dims.get_mesh("batch")

    # ── Training loop ─────────────────────────────────────────────────────
    data_iterator = iter(dataloader)
    for step in range(num_steps):
        optimizer.zero_grad()

        # torchtitan dataloader yields ({"input": input_ids}, labels)
        # both of shape (batch, seq_len) — already shifted, no padding.
        input_dict, labels = next(data_iterator)
        input_ids = input_dict["input"].to(device)
        labels = labels.to(device)

        # No padding in C4 stream — all tokens are valid. labels is already
        # on `device`, so the count is too.
        local_valid_tokens = (labels != IGNORE_INDEX).sum()
        global_valid_tokens = dist_utils.dist_sum(local_valid_tokens, batch_mesh)

        # Forward + loss under train_context (enables loss_parallel if TP).
        # input_ids and labels are same length (seq_len), already shifted by
        # the dataloader: pred aligns directly with labels — no slicing needed.
        with train_context():
            pred = model(input_ids)  # (batch, seq_len, vocab) as Shard(-1) DTensor
            loss_sum = F.cross_entropy(
                pred.flatten(0, 1).float(),
                labels.flatten(0, 1),
                reduction="sum",
                ignore_index=IGNORE_INDEX,
            )
            # Normalize by the GLOBAL token count so the loss matches the
            # single-device value regardless of the DP degree.
            loss = loss_sum / global_valid_tokens
            del pred  # release logits before backward to reduce peak memory
            loss.backward()

        # Gradient clipping (torchtitan's implementation).
        grad_norm = dist_utils.clip_grad_norm_(
            list(model.parameters()), max_norm, foreach=True
        )

        optimizer.step()

        if rank == 0:
            print(
                f"Step {step:>4d} | Loss: {loss.item():.4f} | "
                f"Grad norm: {grad_norm.item():.4f}"
            )

    dist.destroy_process_group()