56 changes: 56 additions & 0 deletions run_compare.sh
@@ -0,0 +1,56 @@
#!/bin/bash
set -euo pipefail
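
# Compare FSDP+TP against FSDP-only training: run both for 20 steps with a
# checkpoint save/resume at step 10, then diff their loss/grad logs.
# Usage: [MODEL_NAME=<hf-model-id>] ./run_compare.sh   (expects >= 6 GPUs)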

SCRIPT="train_fsdp_tp.py"
LOG_FSDP_TP="log.txt"
LOG_FSDP_ONLY="ref.txt"

MODEL_NAME="${MODEL_NAME:-hf-internal-testing/tiny-random-MixtralForCausalLM}"
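# Shared seed/lr so the two runs' logs are directly comparable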
COMMON_ARGS="--model_name $MODEL_NAME --lr 3e-4 --seed 42"

rm -rf ./checkpoints_tp ./checkpoints_tp_resumed ./checkpoints_fsdp ./checkpoints_fsdp_resumed

echo "=== Phase 1: Train steps 0-9, save checkpoint ==="
echo "--- Launching FSDP+TP and FSDP-only in parallel ---"

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29500 \
    $SCRIPT $COMMON_ARGS --fsdp_size 2 --tp_size 2 --enable_sp \
    --num_steps 10 --save_dir ./checkpoints_tp > "${LOG_FSDP_TP}.phase1" 2>&1 &
PID1=$!

CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=29501 \
    $SCRIPT $COMMON_ARGS --fsdp_size 2 \
    --num_steps 10 --save_dir ./checkpoints_fsdp > "${LOG_FSDP_ONLY}.phase1" 2>&1 &
PID2=$!

echo "FSDP+TP PID=$PID1 | FSDP-only PID=$PID2"
wait "$PID1" || { echo "Phase 1 FSDP+TP failed (exit $?)"; cat "${LOG_FSDP_TP}.phase1"; exit 1; }
echo "Phase 1 FSDP+TP done"
wait "$PID2" || { echo "Phase 1 FSDP-only failed (exit $?)"; cat "${LOG_FSDP_ONLY}.phase1"; exit 1; }
echo "Phase 1 FSDP-only done"

echo ""
echo "=== Phase 2: Resume from checkpoint, train steps 10-19, save ==="
echo "--- Launching FSDP+TP and FSDP-only in parallel ---"

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29500 \
    $SCRIPT $COMMON_ARGS --fsdp_size 2 --tp_size 2 --enable_sp \
    --num_steps 10 --start_step 10 \
    --resume_dir ./checkpoints_tp --save_dir ./checkpoints_tp_resumed > "${LOG_FSDP_TP}.phase2" 2>&1 &
PID1=$!

CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=29501 \
    $SCRIPT $COMMON_ARGS --fsdp_size 2 \
    --num_steps 10 --start_step 10 \
    --resume_dir ./checkpoints_fsdp --save_dir ./checkpoints_fsdp_resumed > "${LOG_FSDP_ONLY}.phase2" 2>&1 &
PID2=$!

echo "FSDP+TP PID=$PID1 | FSDP-only PID=$PID2"
wait "$PID1" || { echo "Phase 2 FSDP+TP failed (exit $?)"; cat "${LOG_FSDP_TP}.phase2"; exit 1; }
echo "Phase 2 FSDP+TP done"
wait "$PID2" || { echo "Phase 2 FSDP-only failed (exit $?)"; cat "${LOG_FSDP_ONLY}.phase2"; exit 1; }
echo "Phase 2 FSDP-only done"

# Combine phase logs
cat "${LOG_FSDP_TP}.phase1" "${LOG_FSDP_TP}.phase2" > "$LOG_FSDP_TP"
cat "${LOG_FSDP_ONLY}.phase1" "${LOG_FSDP_ONLY}.phase2" > "$LOG_FSDP_ONLY"

echo ""
echo "=== Full Loss & Grad Diff (steps 0-19) ==="
git diff --no-index --color --word-diff=color "$LOG_FSDP_TP" "$LOG_FSDP_ONLY" || true
160 changes: 160 additions & 0 deletions run_verify_all.sh
@@ -0,0 +1,160 @@
#!/bin/bash
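
# Run verify_loading.py under every parallelism mode listed in JOBS, packing
# jobs onto the available GPUs round-robin, then summarize PASS/FAIL per mode
# and check that all passing modes report the same loss.
# Usage: ./run_verify_all.sh   (GPU count auto-detected via nvidia-smi)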

GREEN='\033[0;32m'
RED='\033[0;31m'
CYAN='\033[0;36m'
YELLOW='\033[1;33m'
BOLD='\033[1m'
DIM='\033[0;90m'
NC='\033[0m'

SCRIPT="verify_loading.py"
LOGDIR="$(dirname "$0")/verify_logs"
mkdir -p "$LOGDIR"

NUM_GPUS=$(nvidia-smi -L | wc -l)

# Job definitions: "mode nproc_per_node"
declare -a JOBS=(
    "single_gpu 1"
    "fsdp 2"
    "tp 2"
    "tp_sp 2"
    "tp_fsdp 4"
    "tp_sp_fsdp 4"
)
MODE_NAMES=(single_gpu fsdp tp tp_sp tp_fsdp tp_sp_fsdp)

echo -e "${BOLD}=========================================="
echo -e " Verify Loading (${NUM_GPUS} GPUs available)"
echo -e " Modes: ${MODE_NAMES[*]}"
echo -e " Logs: $LOGDIR/"
echo -e "==========================================${NC}"
echo ""

# ============================================================
# Round-robin GPU scheduler
# ============================================================
NEXT_GPU=0
MASTER_PORT=29500
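# Each torchrun job gets its own rendezvous port (incremented per launch)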
PIDS=()
PID_MODES=()

for job in "${JOBS[@]}"; do
    mode=${job% *}
    nproc=${job#* }

    # Not enough GPUs left in this round: wait for the running jobs, report
    # their status, and reset BOTH tracking arrays so indices stay aligned
    # (resetting only PIDS would desync it from PID_MODES in the final loop).
    if [ $((NEXT_GPU + nproc)) -gt "$NUM_GPUS" ]; then
        echo -e "${DIM} (waiting for current round to finish...)${NC}"
        for i in "${!PIDS[@]}"; do
            if wait "${PIDS[$i]}" 2>/dev/null; then
                echo -e " ${GREEN}βœ“${NC} ${PID_MODES[$i]}"
            else
                echo -e " ${RED}βœ—${NC} ${PID_MODES[$i]} (exit $?)"
            fi
        done
        PIDS=()
        PID_MODES=()
        NEXT_GPU=0
    fi

    # Build CUDA_VISIBLE_DEVICES range
    GPU_END=$((NEXT_GPU + nproc - 1))
    GPUS=""
    for g in $(seq "$NEXT_GPU" "$GPU_END"); do
        [ -n "$GPUS" ] && GPUS="${GPUS},"
        GPUS="${GPUS}${g}"
    done

    echo -e " ${CYAN}[${mode}]${NC} GPUs ${NEXT_GPU}-${GPU_END} (nproc=${nproc})"

    if [ "$nproc" -eq 1 ]; then
        CUDA_VISIBLE_DEVICES="$GPUS" python "$SCRIPT" --mode "$mode" \
            > "$LOGDIR/${mode}.log" 2>&1 &
    else
        CUDA_VISIBLE_DEVICES="$GPUS" torchrun \
            --nproc_per_node="$nproc" --master_port="$MASTER_PORT" \
            "$SCRIPT" --mode "$mode" \
            > "$LOGDIR/${mode}.log" 2>&1 &
        ((MASTER_PORT++))
    fi

    PIDS+=($!)
    PID_MODES+=("$mode")
    NEXT_GPU=$((GPU_END + 1))
done

# Wait for remaining jobs
echo ""
echo -e "${BOLD}Waiting for all jobs to finish...${NC}"
for i in "${!PIDS[@]}"; do
    mode="${PID_MODES[$i]}"
    if wait "${PIDS[$i]}"; then
        echo -e " ${GREEN}βœ“${NC} ${mode}"
    else
        echo -e " ${RED}βœ—${NC} ${mode} (exit $?)"
    fi
done

# ============================================================
# Results
# ============================================================
echo ""
echo -e "${BOLD}=== Results ===${NC}"
for mode in "${MODE_NAMES[@]}"; do
    log="$LOGDIR/$mode.log"
    loss_before=$(grep -oP 'loss_before = \K[0-9.]+' "$log" 2>/dev/null)
    loss_after=$(grep -oP 'loss_after = \K[0-9.]+' "$log" 2>/dev/null)
    if grep -q '^PASS' "$log" 2>/dev/null; then
        printf " ${GREEN}%-12s PASS (before=%-10s after=%s)${NC}\n" "$mode" "$loss_before" "$loss_after"
    elif [ -n "$loss_before" ]; then
        diff=$(grep -oP 'diff = \K[0-9.e+-]+' "$log" 2>/dev/null)
        printf " ${RED}%-12s FAIL (before=%-10s after=%-10s diff=%s)${NC}\n" "$mode" "$loss_before" "$loss_after" "$diff"
    else
        printf " ${RED}%-12s ERROR (see log)${NC}\n" "$mode"
    fi
done

# ============================================================
# Cross-mode loss comparison
# ============================================================
echo ""
echo -e "${BOLD}=== Cross-mode loss comparison (PASS modes only) ===${NC}"
REF_LOSS=""
ALL_MATCH=1
for mode in "${MODE_NAMES[@]}"; do
    log="$LOGDIR/$mode.log"
    # Only include modes where save/load roundtrip passed
    if ! grep -q '^PASS' "$log" 2>/dev/null; then
        continue
    fi
    loss=$(grep -oP 'loss_before = \K[0-9.]+' "$log" 2>/dev/null)
    if [ -z "$loss" ]; then
        continue
    fi
    if [ -z "$REF_LOSS" ]; then
        REF_LOSS="$loss"
        printf " ${GREEN}%-12s %s (reference)${NC}\n" "$mode" "$loss"
    elif [ "$loss" = "$REF_LOSS" ]; then
        printf " ${GREEN}%-12s %s${NC}\n" "$mode" "$loss"
    else
        printf " ${YELLOW}%-12s %s (differs from %s)${NC}\n" "$mode" "$loss" "$REF_LOSS"
        ALL_MATCH=0
    fi
done
if [ "$ALL_MATCH" -eq 1 ] && [ -n "$REF_LOSS" ]; then
    echo -e " ${GREEN}All modes produce the same loss.${NC}"
fi

# Hints for failures
HAS_FAIL=0
for mode in "${MODE_NAMES[@]}"; do
    if ! grep -q '^PASS' "$LOGDIR/$mode.log" 2>/dev/null; then
        HAS_FAIL=1
    fi
done
if [ "$HAS_FAIL" -eq 1 ]; then
    echo ""
    echo -e "${YELLOW}Some modes failed. Check logs:${NC}"
    for mode in "${MODE_NAMES[@]}"; do
        if ! grep -q '^PASS' "$LOGDIR/$mode.log" 2>/dev/null; then
            echo -e " ${YELLOW}cat $LOGDIR/$mode.log${NC}"
        fi
    done
fi
1 change: 1 addition & 0 deletions tmp.py
@@ -0,0 +1 @@
"&&"
63 changes: 63 additions & 0 deletions tmp_generate.py
@@ -0,0 +1,63 @@
import argparse
import os

import torch
from torch.distributed.elastic.multiprocessing.errors import record

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.distributed import DistributedConfig
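
# Tensor-parallel generation smoke test. Launch with torchrun so that RANK and
# WORLD_SIZE are set, matching the tp_size=4 used below, e.g.
#   torchrun --nproc_per_node=4 tmp_generate.py [--profile]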

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# model_id = "Qwen/Qwen3-14B"
# model_id = "Qwen/Qwen3-0.6B"
# model_id = "Qwen/Qwen1.5-MoE-A2.7B-Chat"
# model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507"

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}")
# The process group needs to be initialized explicitly so the barrier used before loading works
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, device_id=device)

@record
def main(args):
    distributed_config = DistributedConfig(tp_size=4, tp_plan="auto")
    model = AutoModelForCausalLM.from_pretrained(model_id, distributed_config=distributed_config, dtype=torch.bfloat16)
    # model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    messages = [
        {"role": "user", "content": "What do you think about life?"},
    ]
    # `return_dict=True` is needed so `inputs` is a BatchEncoding (with `input_ids`
    # and `attention_mask`) rather than a bare tensor, and `add_generation_prompt`
    # makes the model generate an assistant turn instead of continuing the user's
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_dict=True, return_tensors="pt"
    ).to(model.device)
    input_size = inputs.input_ids.shape[-1]

    if args.profile:
        # Warmup
        with torch.no_grad():
            _ = model.generate(**inputs, max_new_tokens=5, do_sample=False)

        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True,
        ) as prof:
            output = model.generate(**inputs, max_new_tokens=2, do_sample=False)

        if rank == 0:
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
            prof.export_chrome_trace("trace.json")
    else:
        output = model.generate(**inputs, max_new_tokens=100, do_sample=False)

    text = tokenizer.batch_decode(output[:, input_size:])[0]
    if rank == 0:
        print(text)

parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true")
args = parser.parse_args()

main(args)

torch.distributed.destroy_process_group()