diff --git a/run_compare.sh b/run_compare.sh
new file mode 100644
index 000000000000..eb47e1841fa9
--- /dev/null
+++ b/run_compare.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+set -euo pipefail
+
+SCRIPT="train_fsdp_tp.py"
+LOG_FSDP_TP="log.txt"
+LOG_FSDP_ONLY="ref.txt"
+
+MODEL_NAME="${MODEL_NAME:-hf-internal-testing/tiny-random-MixtralForCausalLM}"
+COMMON_ARGS="--model_name $MODEL_NAME --lr 3e-4 --seed 42"
+
+rm -rf ./checkpoints_tp ./checkpoints_tp_resumed ./checkpoints_fsdp ./checkpoints_fsdp_resumed
+
+echo "=== Phase 1: Train steps 0-9, save checkpoint ==="
+echo "--- Launching FSDP+TP and FSDP-only in parallel ---"
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29500 \
+    $SCRIPT $COMMON_ARGS --fsdp_size 2 --tp_size 2 --enable_sp \
+    --num_steps 10 --save_dir ./checkpoints_tp > "${LOG_FSDP_TP}.phase1" 2>&1 &
+PID1=$!
+
+CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=29501 \
+    $SCRIPT $COMMON_ARGS --fsdp_size 2 \
+    --num_steps 10 --save_dir ./checkpoints_fsdp > "${LOG_FSDP_ONLY}.phase1" 2>&1 &
+PID2=$!
+
+echo "FSDP+TP PID=$PID1 | FSDP-only PID=$PID2"
+wait $PID1 && echo "Phase 1 FSDP+TP done" || { echo "Phase 1 FSDP+TP failed (exit $?)"; cat "${LOG_FSDP_TP}.phase1"; exit 1; }
+wait $PID2 && echo "Phase 1 FSDP-only done" || { echo "Phase 1 FSDP-only failed (exit $?)"; cat "${LOG_FSDP_ONLY}.phase1"; exit 1; }
+
+echo ""
+echo "=== Phase 2: Resume from checkpoint, train steps 10-19, save ==="
+echo "--- Launching FSDP+TP and FSDP-only in parallel ---"
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29500 \
+    $SCRIPT $COMMON_ARGS --fsdp_size 2 --tp_size 2 --enable_sp \
+    --num_steps 10 --start_step 10 \
+    --resume_dir ./checkpoints_tp --save_dir ./checkpoints_tp_resumed > "${LOG_FSDP_TP}.phase2" 2>&1 &
+PID1=$!
+
+CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=29501 \
+    $SCRIPT $COMMON_ARGS --fsdp_size 2 \
+    --num_steps 10 --start_step 10 \
+    --resume_dir ./checkpoints_fsdp --save_dir ./checkpoints_fsdp_resumed > "${LOG_FSDP_ONLY}.phase2" 2>&1 &
+PID2=$!
+ +echo "FSDP+TP PID=$PID1 | FSDP-only PID=$PID2" +wait $PID1 && echo "Phase 2 FSDP+TP done" || { echo "Phase 2 FSDP+TP failed (exit $?)"; cat "${LOG_FSDP_TP}.phase2"; exit 1; } +wait $PID2 && echo "Phase 2 FSDP-only done" || { echo "Phase 2 FSDP-only failed (exit $?)"; cat "${LOG_FSDP_ONLY}.phase2"; exit 1; } + +# Combine phase logs +cat "${LOG_FSDP_TP}.phase1" "${LOG_FSDP_TP}.phase2" > "$LOG_FSDP_TP" +cat "${LOG_FSDP_ONLY}.phase1" "${LOG_FSDP_ONLY}.phase2" > "$LOG_FSDP_ONLY" + +echo "" +echo "=== Full Loss & Grad Diff (steps 0-19) ===" +git diff --no-index --color --word-diff=color "$LOG_FSDP_TP" "$LOG_FSDP_ONLY" || true diff --git a/run_verify_all.sh b/run_verify_all.sh new file mode 100644 index 000000000000..16aa3267fe9a --- /dev/null +++ b/run_verify_all.sh @@ -0,0 +1,160 @@ +#!/bin/bash + +GREEN='\033[0;32m' +RED='\033[0;31m' +CYAN='\033[0;36m' +YELLOW='\033[1;33m' +BOLD='\033[1m' +DIM='\033[0;90m' +NC='\033[0m' + +SCRIPT="verify_loading.py" +LOGDIR="$(dirname "$0")/verify_logs" +mkdir -p "$LOGDIR" + +NUM_GPUS=$(nvidia-smi -L | wc -l) + +# Job definitions: "mode nproc_per_node" +declare -a JOBS=( + "single_gpu 1" + "fsdp 2" + "tp 2" + "tp_sp 2" + "tp_fsdp 4" + "tp_sp_fsdp 4" +) +MODE_NAMES=(single_gpu fsdp tp tp_sp tp_fsdp tp_sp_fsdp) + +echo -e "${BOLD}==========================================" +echo -e " Verify Loading (${NUM_GPUS} GPUs available)" +echo -e " Modes: ${MODE_NAMES[*]}" +echo -e " Logs: $LOGDIR/" +echo -e "==========================================${NC}" +echo "" + +# ============================================================ +# Round-robin GPU scheduler +# ============================================================ +NEXT_GPU=0 +MASTER_PORT=29500 +PIDS=() +PID_MODES=() + +for job in "${JOBS[@]}"; do + mode=${job% *} + nproc=${job#* } + + # Wait if not enough GPUs left in this round + if [ $((NEXT_GPU + nproc)) -gt "$NUM_GPUS" ]; then + echo -e "${DIM} (waiting for current round to finish...)${NC}" + for pid in "${PIDS[@]}"; do + wait "$pid" 2>/dev/null + done + PIDS=() + NEXT_GPU=0 + fi + + # Build CUDA_VISIBLE_DEVICES range + GPU_END=$((NEXT_GPU + nproc - 1)) + GPUS="" + for g in $(seq "$NEXT_GPU" "$GPU_END"); do + [ -n "$GPUS" ] && GPUS="${GPUS}," + GPUS="${GPUS}${g}" + done + + echo -e " ${CYAN}[${mode}]${NC} GPUs ${NEXT_GPU}-${GPU_END} (nproc=${nproc})" + + if [ "$nproc" -eq 1 ]; then + CUDA_VISIBLE_DEVICES="$GPUS" python "$SCRIPT" --mode "$mode" \ + > "$LOGDIR/${mode}.log" 2>&1 & + else + CUDA_VISIBLE_DEVICES="$GPUS" torchrun \ + --nproc_per_node="$nproc" --master_port="$MASTER_PORT" \ + "$SCRIPT" --mode "$mode" \ + > "$LOGDIR/${mode}.log" 2>&1 & + ((MASTER_PORT++)) + fi + + PIDS+=($!) 
+    PID_MODES+=("$mode")
+    NEXT_GPU=$((GPU_END + 1))
+done
+
+# Wait for remaining jobs
+echo ""
+echo -e "${BOLD}Waiting for all jobs to finish...${NC}"
+for i in "${!PIDS[@]}"; do
+    mode="${PID_MODES[$i]}"
+    if wait "${PIDS[$i]}"; then
+        echo -e "  ${GREEN}✓${NC} ${mode}"
+    else
+        echo -e "  ${RED}✗${NC} ${mode} (exit $?)"
+    fi
+done
+
+# ============================================================
+# Results
+# ============================================================
+echo ""
+echo -e "${BOLD}=== Results ===${NC}"
+for mode in "${MODE_NAMES[@]}"; do
+    log="$LOGDIR/$mode.log"
+    loss_before=$(grep -oP 'loss_before = \K[0-9.]+' "$log" 2>/dev/null)
+    loss_after=$(grep -oP 'loss_after = \K[0-9.]+' "$log" 2>/dev/null)
+    if grep -q '^PASS' "$log" 2>/dev/null; then
+        printf "  ${GREEN}%-12s PASS (before=%-10s after=%s)${NC}\n" "$mode" "$loss_before" "$loss_after"
+    elif [ -n "$loss_before" ]; then
+        diff=$(grep -oP 'diff = \K[0-9.e+-]+' "$log" 2>/dev/null)
+        printf "  ${RED}%-12s FAIL (before=%-10s after=%-10s diff=%s)${NC}\n" "$mode" "$loss_before" "$loss_after" "$diff"
+    else
+        printf "  ${RED}%-12s ERROR (see log)${NC}\n" "$mode"
+    fi
+done
+
+# ============================================================
+# Cross-mode loss comparison
+# ============================================================
+echo ""
+echo -e "${BOLD}=== Cross-mode loss comparison (PASS modes only) ===${NC}"
+REF_LOSS=""
+ALL_MATCH=1
+for mode in "${MODE_NAMES[@]}"; do
+    log="$LOGDIR/$mode.log"
+    # Only include modes where save/load roundtrip passed
+    if ! grep -q '^PASS' "$log" 2>/dev/null; then
+        continue
+    fi
+    loss=$(grep -oP 'loss_before = \K[0-9.]+' "$log" 2>/dev/null)
+    if [ -z "$loss" ]; then
+        continue
+    fi
+    if [ -z "$REF_LOSS" ]; then
+        REF_LOSS="$loss"
+        printf "  ${GREEN}%-12s %s (reference)${NC}\n" "$mode" "$loss"
+    elif [ "$loss" = "$REF_LOSS" ]; then
+        printf "  ${GREEN}%-12s %s${NC}\n" "$mode" "$loss"
+    else
+        printf "  ${YELLOW}%-12s %s (differs from %s)${NC}\n" "$mode" "$loss" "$REF_LOSS"
+        ALL_MATCH=0
+    fi
+done
+if [ "$ALL_MATCH" -eq 1 ] && [ -n "$REF_LOSS" ]; then
+    echo -e "  ${GREEN}All modes produce the same loss.${NC}"
+fi
+
+# Hints for failures
+HAS_FAIL=0
+for mode in "${MODE_NAMES[@]}"; do
+    if ! grep -q '^PASS' "$LOGDIR/$mode.log" 2>/dev/null; then
+        HAS_FAIL=1
+    fi
+done
+if [ "$HAS_FAIL" -eq 1 ]; then
+    echo ""
+    echo -e "${YELLOW}Some modes failed. Check logs:${NC}"
+    for mode in "${MODE_NAMES[@]}"; do
+        if ! grep -q '^PASS' "$LOGDIR/$mode.log" 2>/dev/null; then
+            echo -e "  ${YELLOW}cat $LOGDIR/$mode.log${NC}"
+        fi
+    done
+fi
diff --git a/tmp_generate.py b/tmp_generate.py
new file mode 100644
index 000000000000..9685bed643ed
--- /dev/null
+++ b/tmp_generate.py
@@ -0,0 +1,63 @@
+import argparse
+import os
+
+import torch
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.distributed import DistributedConfig
+
+model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+# model_id = "Qwen/Qwen3-14B"
+# model_id = "Qwen/Qwen3-0.6B"
+# model_id = "Qwen/Qwen1.5-MoE-A2.7B-Chat"
+# model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+device = torch.device(f"cuda:{rank}")
+# Need to be initialized explicitly to use the `barrier` before loading
+torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, device_id=device)
+
+@record
+def main(args):
+
+    distributed_config = DistributedConfig(tp_size=4, tp_plan="auto")
+    model = AutoModelForCausalLM.from_pretrained(model_id, distributed_config=distributed_config, dtype=torch.bfloat16)
+    # model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    messages = [
+        {"role": "user", "content": "What do you think about life?"},
+    ]
+    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)
+    input_size = inputs.input_ids.shape[-1]
+
+    if args.profile:
+        # Warmup
+        with torch.no_grad():
+            _ = model.generate(**inputs, max_new_tokens=5, do_sample=False)
+
+        with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+            record_shapes=True,
+        ) as prof:
+            output = model.generate(**inputs, max_new_tokens=2, do_sample=False)
+
+        if rank == 0:
+            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
+            prof.export_chrome_trace("trace.json")
+    else:
+        output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+
+    text = tokenizer.batch_decode(output[:, input_size:])[0]
+    if rank == 0:
+        print(text)
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--profile", action="store_true")
+args = parser.parse_args()
+
+main(args)
+
+torch.distributed.destroy_process_group()
\ No newline at end of file
diff --git a/train_fsdp_tp.py b/train_fsdp_tp.py
new file mode 100644
index 000000000000..0232f8b3bc3d
--- /dev/null
+++ b/train_fsdp_tp.py
@@ -0,0 +1,125 @@
+# torchrun --nproc_per_node=4 train_fsdp_tp.py
+
+import argparse
+import os
+
+import torch
+from datasets import load_dataset
+from torch.distributed.tensor import DTensor
+from torch.utils.data import DataLoader
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.distributed import DistributedConfig
+from transformers.distributed.utils import load_optimizer, save_optimizer
+
+def build_packed_dataset(dataset_name, tokenizer, seq_len, dp_rank, dp_world_size):
+    """Stream + tokenize + greedy-pack documents into fixed-length (input, label) windows."""
+    ds = load_dataset(dataset_name, name="en", split="train", streaming=True)
+    ds = ds.shard(num_shards=dp_world_size, index=dp_rank)
+    buf, w = [], seq_len + 1
+
+    def pack(batch):
+        for t in batch["text"]:
+            buf.extend(tokenizer(t)["input_ids"])
+        ids, lbls = [], []
+        while len(buf) >= w:
+            ids.append(buf[:seq_len]); lbls.append(buf[1:w]); del buf[:w]
+        return {"input_ids": ids, "labels": lbls}
+
+    ds = ds.map(pack, batched=True, remove_columns=ds.column_names)
+    return ds.with_format("torch")
+
+def build_fixed_batches(dp_rank):
+    """Load pre-generated fixed batches for a given DP rank."""
+    return torch.load(f"fixed_batches_dp{dp_rank}.pt", weights_only=True)
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_name", type=str, default="Qwen/Qwen3-0.6B")
+    parser.add_argument("--num_steps", type=int, default=20)
+    parser.add_argument("--lr", type=float, default=3e-4)
+    parser.add_argument("--seq_len", type=int, default=512)
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--save_dir", type=str, default="./checkpoints")
+    parser.add_argument("--tp_size", type=int, default=0, help="Tensor parallel size (0 = disabled)")
+    parser.add_argument("--fsdp_size", type=int, default=0, help="FSDP size (0 = disabled)")
+    parser.add_argument("--enable_sp", action="store_true", help="Enable sequence parallelism")
+    parser.add_argument("--seed", type=int, default=42, help="Random seed")
+    parser.add_argument("--fixed_batches", action="store_true", help="Use pre-generated fixed batches instead of C4")
+    parser.add_argument("--resume_dir", type=str, default=None, help="Resume from this checkpoint directory")
+    parser.add_argument("--start_step", type=int, default=0, help="Starting step number (for logging)")
+    args = parser.parse_args()
+
+    torch.distributed.init_process_group(backend="nccl")
+    rank, local_rank = int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"])
+    torch.cuda.set_device(local_rank)
+    torch.manual_seed(args.seed)
+
+    dc_kwargs = {}
+    if args.tp_size > 0:
+        dc_kwargs["tp_size"] = args.tp_size
+        dc_kwargs["tp_plan"] = "auto"
+    if args.fsdp_size > 0:
+        dc_kwargs["fsdp_size"] = args.fsdp_size
+        dc_kwargs["fsdp_plan"] = "auto"
+    if args.enable_sp:
+        dc_kwargs["enable_sequence_parallel"] = True
+    distributed_config = DistributedConfig(**dc_kwargs)
+
+    load_path = args.resume_dir if args.resume_dir else args.model_name
+    model = AutoModelForCausalLM.from_pretrained(
+        load_path,
+        distributed_config=distributed_config,
+        dtype=torch.bfloat16,
+    )
+
+    dp_rank = model.device_mesh["fsdp"].get_local_rank() if "fsdp" in model.device_mesh.mesh_dim_names else 0
+    dp_size = model.device_mesh["fsdp"].size() if "fsdp" in model.device_mesh.mesh_dim_names else 1
+
+    if args.fixed_batches:
+        fixed = build_fixed_batches(dp_rank)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        dataset = build_packed_dataset("allenai/c4", tokenizer, args.seq_len, dp_rank, dp_size)
+        dataloader = iter(DataLoader(dataset, batch_size=args.batch_size))
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+
+    if args.resume_dir:
+        load_optimizer(optimizer, os.path.join(args.resume_dir, "optimizer"))
+        if rank == 0:
+            print(f"Resumed optimizer from {args.resume_dir}")
+
+    model.train()
+    for step in range(args.start_step, args.start_step + args.num_steps):
+        if args.fixed_batches:
+            input_ids = fixed[step]["input_ids"].to(f"cuda:{local_rank}")
+            labels = fixed[step]["labels"].to(f"cuda:{local_rank}")
+        else:
+            batch = next(dataloader)
+            input_ids = batch["input_ids"].to(f"cuda:{local_rank}")
+            labels = batch["labels"].to(f"cuda:{local_rank}")
batch["labels"].to(f"cuda:{local_rank}") + loss = model(input_ids, labels=labels).loss + loss.backward() + + # Custom grad clip: convert DTensor grads to local to avoid mixed-mesh torch.stack + grads = [p.grad for p in model.parameters() if p.grad is not None] + local_grads = [g.full_tensor() if isinstance(g, DTensor) else g for g in grads] + total_norm = torch.nn.utils.get_total_norm(local_grads, norm_type=2.0) + torch.nn.utils.clip_grads_with_norm_(grads, max_norm=1.0, total_norm=total_norm) + optimizer.step() + optimizer.zero_grad() + + if rank == 0: + print(f"Step {step:>4d} | Loss: {loss.item():.4f} | Grad norm: {total_norm.item():.4f}") + + # Save model (HF format) and optimizer (DCP) + model.save_pretrained(args.save_dir) + save_optimizer(optimizer, os.path.join(args.save_dir, "optimizer")) + + if rank == 0: + print(f"Saved to {args.save_dir}") + + torch.distributed.destroy_process_group() diff --git a/verify_loading.py b/verify_loading.py new file mode 100644 index 000000000000..ea008f9626f7 --- /dev/null +++ b/verify_loading.py @@ -0,0 +1,137 @@ +# Save/load roundtrip test for distributed models (TP, FSDP, TP+FSDP). +# +# Verifies that save_pretrained → from_pretrained preserves model weights by +# checking that the cross-entropy loss is identical before and after the roundtrip. +# This catches bugs in DTensor gather-on-save and shard-on-read paths. +# +# Usage: +# python verify_loading.py --mode single_gpu +# torchrun --nproc_per_node=2 verify_loading.py --mode fsdp +# torchrun --nproc_per_node=2 verify_loading.py --mode tp +# torchrun --nproc_per_node=4 verify_loading.py --mode tp_fsdp +# MODEL=Qwen/Qwen3-0.6B torchrun --nproc_per_node=2 verify_loading.py --mode tp +import argparse +import os +import shutil + +import torch +from torch.distributed.tensor import DTensor, Replicate + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.distributed import DistributedConfig + + +parser = argparse.ArgumentParser() +parser.add_argument("--mode", choices=["single_gpu", "fsdp", "tp", "tp_sp", "tp_fsdp", "tp_sp_fsdp"], required=True) +parser.add_argument("--model", type=str, default=None, help="Model ID (or set MODEL env var)") +args = parser.parse_args() + +model_id = args.model or os.environ.get("MODEL") or os.environ.get("MODEL_ID") or "hf-internal-testing/tiny-random-MixtralForCausalLM" + +if args.mode != "single_gpu": + torch.distributed.init_process_group(backend="nccl") + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) +else: + rank = 0 + local_rank = 0 + torch.cuda.set_device(0) + +configs = { + "single_gpu": None, + "fsdp": DistributedConfig(fsdp_size=2, fsdp_plan="auto"), + "tp": DistributedConfig(tp_size=2, tp_plan="auto"), + "tp_sp": DistributedConfig(tp_size=2, tp_plan="auto", enable_sequence_parallel=True), + "tp_fsdp": DistributedConfig(tp_size=2, tp_plan="auto", fsdp_size=2, fsdp_plan="auto"), + "tp_sp_fsdp": DistributedConfig(tp_size=2, tp_plan="auto", fsdp_size=2, fsdp_plan="auto", enable_sequence_parallel=True), +} + +tokenizer = AutoTokenizer.from_pretrained(model_id) +text = "The capital of France is Paris. The largest ocean is the Pacific." 
+
+
+def materialize_full_logits(logits: torch.Tensor) -> torch.Tensor:
+    if isinstance(logits, DTensor):
+        with torch.no_grad():
+            return logits.redistribute(placements=[Replicate()] * logits.device_mesh.ndim, async_op=False).to_local()
+    return logits
+
+
+def compute_loss(model):
+    inputs = tokenizer(text, return_tensors="pt").to(f"cuda:{local_rank}")
+    input_ids = inputs["input_ids"]
+    # Pad sequence length to a multiple of tp_size so DTensor Shard(1) splits evenly
+    # across ranks in SP mode. Always pad (even for non-TP modes) so that all modes
+    # compute on the same input and losses are directly comparable.
+    max_tp = max((c.tp_size if c is not None else 1) for c in configs.values())
+    seq_len = input_ids.shape[1]
+    if seq_len % max_tp != 0:
+        pad_len = max_tp - (seq_len % max_tp)
+        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
+        input_ids = torch.cat([input_ids, input_ids.new_full((1, pad_len), pad_token_id)], dim=1)
+    labels = input_ids.clone()
+    labels[:, seq_len:] = -100  # ignore padding in loss
+    position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0)
+
+    model.eval()
+    with torch.no_grad():
+        logits = model(input_ids, position_ids=position_ids).logits
+    logits = materialize_full_logits(logits)
+    # No next-token shift: only before/after equality matters here, so any deterministic function of the logits works.
+    loss = torch.nn.functional.cross_entropy(
+        logits.flatten(0, 1).float(),
+        labels.flatten(0, 1),
+        reduction="mean",
+        ignore_index=-100,
+    )
+    return loss.item()
+
+
+# --- Step 1: Load original model and compute loss ---
+model = AutoModelForCausalLM.from_pretrained(model_id, distributed_config=configs[args.mode], dtype=torch.float32)
+if args.mode == "single_gpu":
+    model = model.to("cuda:0")
+
+loss_before = compute_loss(model)
+if rank == 0:
+    print(f"{args.mode}: loss_before = {loss_before:.6f}")
+
+# --- Step 2: Save to local dir (shared path across ranks) ---
+save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"verify_ckpt_{args.mode}")
+if rank == 0:
+    if os.path.exists(save_dir):
+        shutil.rmtree(save_dir)
+    os.makedirs(save_dir)
+if args.mode != "single_gpu":
+    torch.distributed.barrier()
+model.save_pretrained(save_dir, is_main_process=(rank == 0))
+if rank == 0:
+    print(f"{args.mode}: saved to {save_dir}")
+
+# Ensure all ranks see the saved files before reloading
+if args.mode != "single_gpu":
+    torch.distributed.barrier()
+
+del model
+torch.cuda.empty_cache()
+
+# --- Step 3: Reload from saved checkpoint and compute loss ---
+model2 = AutoModelForCausalLM.from_pretrained(save_dir, distributed_config=configs[args.mode], dtype=torch.float32)
+if args.mode == "single_gpu":
+    model2 = model2.to("cuda:0")
+
+loss_after = compute_loss(model2)
+if rank == 0:
+    print(f"{args.mode}: loss_after = {loss_after:.6f}")
+
+# --- Step 4: Compare ---
+if rank == 0:
+    diff = abs(loss_before - loss_after)
+    print(f"{args.mode}: diff = {diff:.2e}")
+    if diff < 1e-5:
+        print("PASS: save/load roundtrip is lossless")
+    else:
+        print("FAIL: loss mismatch after save/load roundtrip!")
+
+if args.mode != "single_gpu":
+    torch.distributed.destroy_process_group()