Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# SIM: Self Improving Model framework
# Apeiron
[![Build Status](https://github.com/AI-ModCon/BaseSim_Framework/actions/workflows/build-test.yml/badge.svg)](https://github.com/AI-ModCon/BaseSim_Framework/actions/workflows/build-test.yml)
[![Coverage Status](https://codecov.io/gh/AI-ModCon/BaseSim_Framework/badge.svg?branch=main)](https://codecov.io/gh/AI-ModCon/BaseSim_Framework?branch=main)

A PyTorch framework for continuous learning that automatically detects concept drift in data streams and adapts models through JVP regularized retraining.
A PyTorch framework for continual learning that automatically detects concept drift in data streams and adapts models through JVP regularized retraining.

## What This Repository Does

Expand All @@ -17,16 +17,35 @@ The pipeline runs on a changing data stream and loops through these stages:
Core modules:

- `src/main.py`: entry point
- `src/config/configuration.py`: TOML/env/CLI config assembly
- `src/driver/continuous_monitor.py`: monitoring + drift loop
- `src/training/continuous_trainer.py`: CL training loop
- `src/training/updater/`: CL update strategies
- `src/drift_detection/`: detectors and detector factory
- `examples/`: concrete model harness implementations
- `src/apeiron/config/configuration.py`: TOML/env/CLI config assembly
- `src/apeiron/driver/continuous_monitor.py`: monitoring + drift loop
- `src/apeiron/training/continuous_trainer.py`: CL training loop
- `src/apeiron/training/updater/`: CL update strategies
- `src/apeiron/drift_detection/`: detectors and detector factory
- `examples/`: standalone example projects (each declares `apeiron` as a dependency)

## Installation

Requires Python `>=3.13,<3.15` and Poetry.
### As a dependency in your project

```toml
# pyproject.toml
[tool.poetry.dependencies]
apeiron = "^0.1.0" # once published to PyPI

# Or as a path dependency during development
apeiron = { path = "../apeiron/", develop = true }
```

```python
from apeiron import BaseModelHarness, ContinuousMonitor, build_config
from apeiron.drift_detection import ADWINDetector
from apeiron.training.updater import BaseUpdater
```

### For development in this repo

Requires Python `>=3.13,<3.14` and Poetry.

```bash
poetry install
Expand All @@ -39,6 +58,7 @@ From the project root:
```bash
poetry run python -m src.main --config examples/mnist/mnist.toml
poetry run python -m src.main --config examples/cifar/cifar10_vit.toml
poetry run python -m src.main --config examples/imagenet/imagenet_vit.toml # requires ImageNet data at data.path
```

## Metrics Logging
Expand Down Expand Up @@ -105,7 +125,7 @@ poetry run mypy .

Platform-specific deployment guides:

- [NERSC Perlmutter](./src/deployment/perlmutter/README.md)
- [NERSC Perlmutter](./src/apeiron/deployment/perlmutter/README.md)

## What `main.py` Does
- Builds the `DummyCNN_MNIST` model defined in `src/model/DummyCNN_MNIST.py`, a cross-entropy loss, and an Adam optimizer.
Expand All @@ -127,4 +147,4 @@ Training logs report the task id, training/test accuracy, and replay-memory accu

Platform-specific deployment guides:

- [OLCF Frontier](./src/deployment/frontier/README.md)
- [OLCF Frontier](./src/apeiron/deployment/frontier/README.md)
41 changes: 34 additions & 7 deletions examples/cifar/cifar10_vgg11.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# mnist.min.toml
version = "1"
# cifar10_vgg11.toml
seed = 1337
device = "auto"
multi_gpu = false
verbosity = "INFO" # DEBUG, INFO, INFO:0-9, WARNING, ERROR, CRITICAL

[model]
name = "vgg11"
name = "vgg11"
pretrained_path = "examples/cifar/cifar10_vgg11.pth"

[data]
Expand All @@ -14,11 +14,38 @@ path = "" # folder or file later for custom dataset

[train]
batch_size = 32
grad_accumulation_steps = 1
num_workers = 4
init_lr = 0.001
max_iter = 300

[continual_learning]
update_mode = "base"
jvp_lambda = 10
jvp_deltax_norm = 1

[continuous_learning]
jvp_reg = 0.001
deltax_norm = 1
max_iter = 300
# ewc hyperparameters
ewc_lambda = 1000.0
ewc_ema_decay = 0.95

# kfac hyperparameters
kfac_lambda = 1e-2
kfac_ema_decay = 0.95

[drift_detection]
detector_name = "ADWINDetector"
detection_interval = 10 # Check drift every 10 batches
aggregation = "mean" # Average metric over 10 batches
metric_index = 0 # Monitor accuracy (0=accuracy, 1=loss)
reset_after_learning = false
max_stream_updates = 20

# ADWIN hyperparameters
adwin_delta = 0.002
adwin_minor_threshold = 0.3
adwin_moderate_threshold = 0.6

[visualization]
baseline = 90.0
input = "output/cifar.csv"
output = "output/cifar10_vgg11_dashboard.png"
6 changes: 3 additions & 3 deletions examples/cifar/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, ConcatDataset

from model.torch_model_harness import BaseModelHarness
from config.configuration import Config
from apeiron.model.torch_model_harness import BaseModelHarness
from apeiron.config.configuration import Config
from examples.cifar.src.utils import (
get_cifar_train,
get_cifar_val,
Expand All @@ -16,7 +16,7 @@
sample_aug,
)
from examples.cifar.src.utils import load_model
from evaluation.metrics import accuracy
from apeiron.evaluation.metrics import accuracy


class VisionModelCifar(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion examples/cifar/src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
from config.configuration import Config
from apeiron.config.configuration import Config
from examples.cifar.src import cnns, vision_transformers


Expand Down
6 changes: 3 additions & 3 deletions examples/imagenet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, ConcatDataset

from model.torch_model_harness import BaseModelHarness
from config.configuration import Config
from apeiron.model.torch_model_harness import BaseModelHarness
from apeiron.config.configuration import Config
from examples.imagenet.src.utils import (
get_imagenet_train,
get_imagenet_val,
Expand All @@ -16,7 +16,7 @@
sample_aug,
)
from examples.imagenet.src.utils import load_model
from evaluation.metrics import accuracy
from apeiron.evaluation.metrics import accuracy


class VisionModelImageNet(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion examples/imagenet/src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
from config.configuration import Config
from apeiron.config.configuration import Config
from examples.cifar.src import cnns, vision_transformers


Expand Down
6 changes: 3 additions & 3 deletions examples/mnist/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, ConcatDataset

from model.torch_model_harness import BaseModelHarness
from config.configuration import Config
from apeiron.model.torch_model_harness import BaseModelHarness
from apeiron.config.configuration import Config
from examples.mnist.utils import (
get_mnist_train,
get_mnist_val,
Expand All @@ -17,7 +17,7 @@
make_loader,
sample_aug,
)
from evaluation.metrics import accuracy
from apeiron.evaluation.metrics import accuracy


class Cnn(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions examples/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from config.configuration import Config
from model.torch_model_harness import BaseModelHarness
from apeiron.config.configuration import Config
from apeiron.model.torch_model_harness import BaseModelHarness


def get_example(cfg: Config) -> BaseModelHarness:
Expand Down
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
[project]
name = "basesim-framework"
name = "apeiron"
version = "0.1.0"
description = ""
description = "A PyTorch continual learning framework for real-time concept drift detection and model adaptation."
authors = [
{name = "Your Name",email = "you@example.com"}
]
readme = "README.md"
requires-python = ">=3.13,<3.15"
requires-python = ">=3.13,<3.14"
dependencies = [
"torch (>=2.0)",
"numpy (>=2.3.4,<3.0.0)",
Expand Down Expand Up @@ -41,7 +41,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
packages = [
{ include = "*", from = "src" },
{ include = "apeiron", from = "src" },
]

[tool.ruff]
Expand Down
56 changes: 56 additions & 0 deletions src/apeiron/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Apeiron: A PyTorch continual learning framework for real-time concept drift detection and model adaptation."""

from apeiron.config.configuration import (
Config,
ModelCfg,
DataCfg,
TrainCfg,
ContinualLearningCfg,
DriftDetectionCfg,
VisualizationCfg,
LoggingCfg,
build_config,
)
from apeiron.model.torch_model_harness import BaseModelHarness
from apeiron.driver.continuous_monitor import ContinuousMonitor
from apeiron.drift_detection import (
BaseDriftDetector,
DriftSignal,
LearningRegime,
ADWINDetector,
KSWINDetector,
PageHinkleyDetector,
ModelPerformanceDetector,
EnsembleDetector,
ModelEvalDetector,
)
from apeiron.training import ContinuousTrainer
from apeiron.training.updater import BaseUpdater
from apeiron.logger import Logger, get_logger

__all__ = [
"Config",
"ModelCfg",
"DataCfg",
"TrainCfg",
"ContinualLearningCfg",
"DriftDetectionCfg",
"VisualizationCfg",
"LoggingCfg",
"build_config",
"BaseModelHarness",
"ContinuousMonitor",
"BaseDriftDetector",
"DriftSignal",
"LearningRegime",
"ADWINDetector",
"KSWINDetector",
"PageHinkleyDetector",
"ModelPerformanceDetector",
"EnsembleDetector",
"ModelEvalDetector",
"ContinuousTrainer",
"BaseUpdater",
"Logger",
"get_logger",
]
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from logger.logger import MetricsBackend
from apeiron.logger.logger import MetricsBackend


def get_available_device(multi_gpu: bool = False) -> torch.device:
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
- EnsembleDetector: Combine multiple detectors
"""

from drift_detection.detectors.base import (
from apeiron.drift_detection.detectors.base import (
BaseDriftDetector,
DriftSignal,
LearningRegime,
)
from drift_detection.detectors.statistical_detectors import (
from apeiron.drift_detection.detectors.statistical_detectors import (
ADWINDetector,
KSWINDetector,
PageHinkleyDetector,
)
from drift_detection.detectors.model_performance_detector import (
from apeiron.drift_detection.detectors.model_performance_detector import (
ModelPerformanceDetector,
EnsembleDetector,
ModelEvalDetector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Optional, List
from evidently import Report
from evidently.presets import DataDriftPreset
from drift_detection.detectors.base import (
from apeiron.drift_detection.detectors.base import (
BaseDriftDetector,
DriftSignal,
LearningRegime,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
from river import drift as river_drift
from drift_detection.detectors.base import (
from apeiron.drift_detection.detectors.base import (
BaseDriftDetector,
DriftSignal,
LearningRegime,
Expand Down
Loading
Loading