diff --git a/examples/mnist/mnist.toml b/examples/mnist/mnist.toml
index 30bf0c8..414de35 100644
--- a/examples/mnist/mnist.toml
+++ b/examples/mnist/mnist.toml
@@ -8,6 +8,10 @@ verbosity = "INFO" # DEBUG, INFO, INFO:0-9, WARNING, ERROR, CRITICAL
 name = "dummy"
 pretrained_path = "examples/mnist/mnist.pth"
 
+# FIFO checkpointing of post-adaptation model snapshots
+max_ckpts = 0 # set to a positive value to enable; keeps the N most recent checkpoints
+ckpts_path = "output/mnist/"
+
 [data]
 name = "mnist"
 path = "" # folder or file later for custom dataset
diff --git a/src/config/configuration.py b/src/config/configuration.py
index f2d707d..3bc3ece 100644
--- a/src/config/configuration.py
+++ b/src/config/configuration.py
@@ -87,7 +87,10 @@ def _select_best_gpu() -> int | None:
 class ModelCfg:
     name: str
     pretrained_path: str
-    # ckpt: str | None = None # perhaps later support checkpointing files
+
+    # FIFO checkpointing: 0 disables; N keeps the last N post-continual-learning snapshots
+    max_ckpts: int = 0
+    ckpts_path: str = ""
 
 
 @dataclass(frozen=True)
diff --git a/src/driver/continuous_monitor.py b/src/driver/continuous_monitor.py
index 299578c..da6c32f 100644
--- a/src/driver/continuous_monitor.py
+++ b/src/driver/continuous_monitor.py
@@ -285,6 +285,10 @@ def _handle_drift(self, drift_signal: DriftSignal) -> None:
             drift_event_id=self.drift_event_count,
         )
 
+        if self.modelHarness.ckpts_enabled:
+            ckptpath = self.modelHarness.save_ckpt(event=self.drift_event_count)
+            self.logger.info(f"* Checkpoint saved to: {ckptpath}", level=0)
+
         self.logger.info("<- Continual learning complete.", level=0)
 
         # Optionally reset detector after learning
diff --git a/src/model/torch_model_harness.py b/src/model/torch_model_harness.py
index ff3a50a..fee7a53 100644
--- a/src/model/torch_model_harness.py
+++ b/src/model/torch_model_harness.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import Any, Optional, Callable, Tuple, List, Dict
 
 import torch
@@ -173,3 +174,23 @@ def history_eval(self) -> Optional[List[float]]:
 
             raise RuntimeError("Empty loader: nothing to evaluate.")
         return [s / c for s, c in zip(sums, counts)]
+
+    @property
+    def ckpts_enabled(self) -> bool:
+        return self.cfg.model.max_ckpts > 0 and bool(self.cfg.model.ckpts_path)
+
+    def save_ckpt(self, event: int) -> str:
+        """Persist model state, evict oldest when over budget."""
+        d = Path(self.cfg.model.ckpts_path)
+        d.mkdir(parents=True, exist_ok=True)
+
+        fname = f"drift_adaptation_{event}.pt"
+        torch.save(self.model.state_dict(), d / fname)
+        (d / "latest").write_text(fname)
+
+        # Evict oldest checkpoints (by mtime) until within the max_ckpts budget
+        alive = sorted(d.glob("drift_adaptation_*.pt"), key=lambda p: p.stat().st_mtime)
+        while len(alive) > self.cfg.model.max_ckpts:
+            alive.pop(0).unlink()
+
+        return str(d / fname)