training/scripts/train_mnemonic_lm.py (45 changes: 42 additions & 3 deletions)
@@ -224,20 +224,44 @@ def train(config, args):
     ckpt_dir = Path(f"checkpoints/{args.config}")
     ckpt_dir.mkdir(parents=True, exist_ok=True)
     global_step = 0
+    global_step_start = 0
     lr = args.lr
     losses = []
 
+    if args.resume:
+        print(f"\n Resuming from {args.resume}")
+        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
+        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+        if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
+            raw_model.load_state_dict(ckpt['model_state_dict'])
+            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+            global_step = ckpt['global_step']
+            losses = ckpt.get('losses', [])
+        else:
+            # Legacy checkpoint: raw state_dict only
+            raw_model.load_state_dict(ckpt)
+            print(" Warning: legacy checkpoint (model weights only, no optimizer state)")
+        global_step_start = global_step
+        if losses:
+            print(f" Resumed at step {global_step}, loss={losses[-1]:.3f}")
+        else:
+            print(f" Resumed at step {global_step}")
+
     model.train()
     optimizer.zero_grad()
     start_time = time.time()
 
     try:
         from tqdm import tqdm
-        pbar = tqdm(total=max_steps, desc="Training")
+        pbar = tqdm(total=max_steps, desc="Training", initial=global_step)
     except ImportError:
         pbar = None
 
     for input_ids, targets in train_loader:
+        if global_step_start > 0 and global_step < global_step_start:
+            global_step += 1
+            continue
+
         input_ids = input_ids.to(device)
         targets = targets.to(device)

@@ -283,7 +307,14 @@ def train(config, args):

         # Periodic checkpoint
         if global_step % args.save_interval == 0:
-            torch.save(model.state_dict(), ckpt_dir / f"step_{global_step}.pt")
+            raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+            torch.save({
+                'model_state_dict': raw_model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'global_step': global_step,
+                'losses': losses[-100:],
+                'args': vars(args),
+            }, ckpt_dir / f"step_{global_step}.pt")
             print(f"\n Checkpoint saved at step {global_step}")
 
         if global_step % 5000 == 0 and global_step > 0:
@@ -319,7 +350,14 @@ def train(config, args):
     else:
         print(" FAIL: Loss did not decrease!")
 
-    torch.save(model.state_dict(), ckpt_dir / "last.pt")
+    raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+    torch.save({
+        'model_state_dict': raw_model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'global_step': global_step,
+        'losses': losses[-100:],
+        'args': vars(args),
+    }, ckpt_dir / "last.pt")
     print(f" Checkpoint: {ckpt_dir}/last.pt")
 
     if not args.no_wandb:
@@ -347,6 +385,7 @@ def main():
parser.add_argument("--compile", action="store_true")
parser.add_argument("--spoke-lr-mult", type=float, default=2.0, help="Spoke LR multiplier")
parser.add_argument("--tokenized-dir", type=str, default=None)
parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint .pt file to resume from")
parser.add_argument("--smoke-test", action="store_true", help="Run 1000 steps to verify pipeline")
args = parser.parse_args()

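Not part of the diff: a minimal sketch of inspecting a checkpoint written in the new dictionary format, e.g. as a quick sanity check before passing it to --resume. The path and config name are hypothetical; the keys are the ones the updated script saves.

# Quick inspection of a checkpoint saved by the updated script.
# The path below is hypothetical; the keys match the dict written in train().
import torch

ckpt = torch.load("checkpoints/my_config/step_10000.pt",
                  map_location="cpu", weights_only=False)

print(ckpt["global_step"])            # step the run had reached when it was saved
print(ckpt["args"])                   # CLI arguments the original run was launched with
print(ckpt["losses"][-1])             # most recent of the (up to) 100 retained losses
print(len(ckpt["model_state_dict"]))  # parameter tensors, saved without the torch.compile wrapper prefix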