From fbdf2f1647424e933839dd14f1a8f4fc0526de2a Mon Sep 17 00:00:00 2001
From: Caleb Gross
Date: Sat, 21 Mar 2026 01:48:26 -0400
Subject: [PATCH] feat: add --resume support to training script

Save full checkpoint state (model + optimizer + step + losses) at each
save interval. Support resuming from both new-format and legacy
(model-only) checkpoints with automatic detection.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 training/scripts/train_mnemonic_lm.py | 45 +++++++++++++++++++++++++--
 1 file changed, 42 insertions(+), 3 deletions(-)

diff --git a/training/scripts/train_mnemonic_lm.py b/training/scripts/train_mnemonic_lm.py
index 604bc69e..b6a58a8c 100644
--- a/training/scripts/train_mnemonic_lm.py
+++ b/training/scripts/train_mnemonic_lm.py
@@ -224,20 +224,44 @@ def train(config, args):
     ckpt_dir = Path(f"checkpoints/{args.config}")
     ckpt_dir.mkdir(parents=True, exist_ok=True)
     global_step = 0
+    global_step_start = 0
     lr = args.lr
     losses = []
+
+    if args.resume:
+        print(f"\n Resuming from {args.resume}")
+        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
+        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+        if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
+            raw_model.load_state_dict(ckpt['model_state_dict'])
+            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+            global_step = ckpt['global_step']
+            losses = ckpt.get('losses', [])
+        else:
+            # Legacy checkpoint: raw state_dict only
+            raw_model.load_state_dict(ckpt)
+            print(" Warning: legacy checkpoint (model weights only, no optimizer state)")
+        global_step_start = global_step
+        if losses:
+            print(f" Resumed at step {global_step}, loss={losses[-1]:.3f}")
+        else:
+            print(f" Resumed at step {global_step}")
 
     model.train()
     optimizer.zero_grad()
     start_time = time.time()
 
     try:
         from tqdm import tqdm
-        pbar = tqdm(total=max_steps, desc="Training")
+        pbar = tqdm(total=max_steps, desc="Training", initial=global_step)
     except ImportError:
         pbar = None
 
     for input_ids, targets in train_loader:
+        if global_step_start > 0:
+            global_step_start -= 1
+            continue
+
         input_ids = input_ids.to(device)
         targets = targets.to(device)
 
@@ -283,7 +307,14 @@
 
         # Periodic checkpoint
         if global_step % args.save_interval == 0:
-            torch.save(model.state_dict(), ckpt_dir / f"step_{global_step}.pt")
+            raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+            torch.save({
+                'model_state_dict': raw_model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'global_step': global_step,
+                'losses': losses[-100:],
+                'args': vars(args),
+            }, ckpt_dir / f"step_{global_step}.pt")
             print(f"\n Checkpoint saved at step {global_step}")
 
         if global_step % 5000 == 0 and global_step > 0:
@@ -319,7 +350,14 @@
     else:
         print(" FAIL: Loss did not decrease!")
 
-    torch.save(model.state_dict(), ckpt_dir / "last.pt")
+    raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+    torch.save({
+        'model_state_dict': raw_model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'global_step': global_step,
+        'losses': losses[-100:],
+        'args': vars(args),
+    }, ckpt_dir / "last.pt")
     print(f" Checkpoint: {ckpt_dir}/last.pt")
 
     if not args.no_wandb:
@@ -347,6 +385,7 @@ def main():
     parser.add_argument("--compile", action="store_true")
    parser.add_argument("--spoke-lr-mult", type=float, default=2.0, help="Spoke LR multiplier")
     parser.add_argument("--tokenized-dir", type=str, default=None)
+    parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint .pt file to resume from")
     parser.add_argument("--smoke-test", action="store_true", help="Run 1000 steps to verify pipeline")
     args = parser.parse_args()
 