From 61db0b246a1667cee979885551d768e8f488ffc9 Mon Sep 17 00:00:00 2001 From: "Rafael Zamora-Resendiz (AMCRD)" Date: Wed, 4 Mar 2026 10:47:00 -0500 Subject: [PATCH 1/2] checkpointing model state at end of continual learning app. --- examples/mnist/mnist.toml | 4 ++++ src/config/configuration.py | 5 ++++- src/driver/continuous_monitor.py | 4 ++++ src/model/torch_model_harness.py | 22 ++++++++++++++++++++++ 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/examples/mnist/mnist.toml b/examples/mnist/mnist.toml index 30bf0c8..414de35 100644 --- a/examples/mnist/mnist.toml +++ b/examples/mnist/mnist.toml @@ -8,6 +8,10 @@ verbosity = "INFO" # DEBUG, INFO, INFO:0-9, WARNING, ERROR, CRITICAL name = "dummy" pretrained_path = "examples/mnist/mnist.pth" +# enables checkpointing +max_ckpts = 0 # Increment to enable +ckpts_path = "output/mnist/" + [data] name = "mnist" path = "" # folder or file later for custom dataset diff --git a/src/config/configuration.py b/src/config/configuration.py index f2d707d..3bc3ece 100644 --- a/src/config/configuration.py +++ b/src/config/configuration.py @@ -87,7 +87,10 @@ def _select_best_gpu() -> int | None: class ModelCfg: name: str pretrained_path: str - # ckpt: str | None = None # perhaps later support checkpointing files + + # FIFO checkpointing: 0 disables, N keeps last N post-CL snapshots + max_ckpts: int = 0 + ckpts_path: str = "" @dataclass(frozen=True) diff --git a/src/driver/continuous_monitor.py b/src/driver/continuous_monitor.py index 299578c..da6c32f 100644 --- a/src/driver/continuous_monitor.py +++ b/src/driver/continuous_monitor.py @@ -285,6 +285,10 @@ def _handle_drift(self, drift_signal: DriftSignal) -> None: drift_event_id=self.drift_event_count, ) + if self.modelHarness.ckpts_enabled: + ckptpath = self.modelHarness.save_ckpt(event=self.drift_event_count) + self.logger.info(f"* Checkpoint saved to: {ckptpath}", level=0) + self.logger.info("<- Continual learning complete.", level=0) # Optionally reset detector after learning diff --git a/src/model/torch_model_harness.py b/src/model/torch_model_harness.py index ff3a50a..c822d79 100644 --- a/src/model/torch_model_harness.py +++ b/src/model/torch_model_harness.py @@ -1,5 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, Optional, Callable, Tuple, List, Dict import torch @@ -173,3 +174,24 @@ def history_eval(self) -> Optional[List[float]]: raise RuntimeError("Empty loader: nothing to evaluate.") return [s / c for s, c in zip(sums, counts)] + + + @property + def ckpts_enabled(self) -> bool: + return self.cfg.model.max_ckpts > 0 and bool(self.cfg.model.ckpts_path) + + def save_ckpt(self, event: int) -> str: + """Persist model state, evict oldest when over budget.""" + d = Path(self.cfg.model.ckpts_path) + d.mkdir(parents=True, exist_ok=True) + + fname = f"drift_adaptation_{event}.pt" + torch.save(self.model.state_dict(), d / fname) + (d / "latest").write_text(fname) + + # Guillotine the oldest survivors + alive = sorted(d.glob("drift_adaptation_*.pt"), key=lambda p: p.stat().st_mtime) + while len(alive) > self.cfg.model.max_ckpts: + alive.pop(0).unlink() + + return str(d / fname) From ed749cf6989641d0335b37d7acdb9eb11eeebdcd Mon Sep 17 00:00:00 2001 From: "Rafael Zamora-Resendiz (AMCRD)" Date: Wed, 4 Mar 2026 12:02:40 -0500 Subject: [PATCH 2/2] Passes ruff and mypy. --- src/model/torch_model_harness.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/model/torch_model_harness.py b/src/model/torch_model_harness.py index c822d79..fee7a53 100644 --- a/src/model/torch_model_harness.py +++ b/src/model/torch_model_harness.py @@ -175,7 +175,6 @@ def history_eval(self) -> Optional[List[float]]: return [s / c for s, c in zip(sums, counts)] - @property def ckpts_enabled(self) -> bool: return self.cfg.model.max_ckpts > 0 and bool(self.cfg.model.ckpts_path)