From cead3e9f458225c95335129f700a20d32aa7307a Mon Sep 17 00:00:00 2001 From: andrewfayres Date: Wed, 1 Apr 2026 09:55:32 -0500 Subject: [PATCH] Modifying package so we can start publishing project. --- README.md | 42 ++++++++++---- examples/cifar/cifar10_vgg11.toml | 41 +++++++++++--- examples/cifar/model.py | 6 +- examples/cifar/src/utils.py | 2 +- examples/imagenet/model.py | 6 +- examples/imagenet/src/utils.py | 2 +- examples/mnist/model.py | 6 +- examples/utils.py | 4 +- poetry.lock | 4 +- pyproject.toml | 8 +-- src/apeiron/__init__.py | 56 +++++++++++++++++++ src/{ => apeiron}/config/__init__.py | 0 src/{ => apeiron}/config/configuration.py | 2 +- src/{ => apeiron}/deployment/__init__.py | 0 .../deployment/frontier/README.md | 0 .../deployment/frontier/install_venv.sh | 0 .../deployment/frontier/mnist_example.sbatch | 0 .../deployment/perlmutter/README.md | 0 .../deployment/perlmutter/install_venv.sh | 0 .../perlmutter/mnist_example.sbatch | 0 src/{ => apeiron}/drift_detection/.gitkeep | 0 src/{ => apeiron}/drift_detection/__init__.py | 6 +- .../drift_detection/detectors/base.py | 0 .../detectors/model_performance_detector.py | 2 +- .../detectors/statistical_detectors.py | 2 +- .../drift_detection/load_drift_detector.py | 16 +++--- src/apeiron/driver/__init__.py | 1 + .../driver/continuous_monitor.py | 14 ++--- src/{ => apeiron}/evaluation/__init__.py | 0 src/{ => apeiron}/evaluation/evaluation.py | 0 src/{ => apeiron}/evaluation/metrics.py | 0 src/{ => apeiron}/logger/__init__.py | 8 +-- src/{ => apeiron}/logger/console_logger.py | 0 src/{ => apeiron}/logger/logger.py | 10 ++-- src/{ => apeiron}/logger/mlflow_logger.py | 4 +- src/{ => apeiron}/logger/wandb_logger.py | 2 +- src/apeiron/model/__init__.py | 1 + .../model/torch_model_harness.py | 2 +- src/{ => apeiron}/profilers/README.md | 0 src/apeiron/profilers/__init__.py | 1 + src/{ => apeiron}/profilers/aten_flops_map.py | 0 src/{ => apeiron}/profilers/count_flops.py | 2 +- src/apeiron/training/__init__.py | 1 + .../training/continuous_trainer.py | 10 ++-- src/apeiron/training/updater/__init__.py | 1 + src/{ => apeiron}/training/updater/base.py | 4 +- .../training/updater/create_updater.py | 14 ++--- src/{ => apeiron}/training/updater/ewc.py | 6 +- src/{ => apeiron}/training/updater/jvp_reg.py | 6 +- src/{ => apeiron}/training/updater/kfac.py | 6 +- .../training/updater/no_updater.py | 2 +- src/main.py | 6 +- src/profilers/__init__.py | 1 - src/training/__init__.py | 1 - src/training/updater/__init__.py | 1 - tests/conftest.py | 6 +- tests/test_config.py | 8 +-- tests/test_continuous_monitor.py | 8 +-- tests/test_continuous_trainer.py | 4 +- tests/test_drift_detection.py | 10 ++-- tests/test_evaluation.py | 2 +- tests/test_logger.py | 4 +- tests/test_model_harness.py | 2 +- tests/test_profiler.py | 2 +- tests/test_updaters.py | 14 ++--- tests/test_valiadation_tests.py | 10 ++-- 66 files changed, 242 insertions(+), 137 deletions(-) create mode 100644 src/apeiron/__init__.py rename src/{ => apeiron}/config/__init__.py (100%) rename src/{ => apeiron}/config/configuration.py (99%) rename src/{ => apeiron}/deployment/__init__.py (100%) rename src/{ => apeiron}/deployment/frontier/README.md (100%) rename src/{ => apeiron}/deployment/frontier/install_venv.sh (100%) rename src/{ => apeiron}/deployment/frontier/mnist_example.sbatch (100%) rename src/{ => apeiron}/deployment/perlmutter/README.md (100%) rename src/{ => apeiron}/deployment/perlmutter/install_venv.sh (100%) rename src/{ => apeiron}/deployment/perlmutter/mnist_example.sbatch (100%) rename src/{ => apeiron}/drift_detection/.gitkeep (100%) rename src/{ => apeiron}/drift_detection/__init__.py (84%) rename src/{ => apeiron}/drift_detection/detectors/base.py (100%) rename src/{ => apeiron}/drift_detection/detectors/model_performance_detector.py (99%) rename src/{ => apeiron}/drift_detection/detectors/statistical_detectors.py (99%) rename src/{ => apeiron}/drift_detection/load_drift_detector.py (78%) create mode 100644 src/apeiron/driver/__init__.py rename src/{ => apeiron}/driver/continuous_monitor.py (97%) rename src/{ => apeiron}/evaluation/__init__.py (100%) rename src/{ => apeiron}/evaluation/evaluation.py (100%) rename src/{ => apeiron}/evaluation/metrics.py (100%) rename src/{ => apeiron}/logger/__init__.py (61%) rename src/{ => apeiron}/logger/console_logger.py (100%) rename src/{ => apeiron}/logger/logger.py (95%) rename src/{ => apeiron}/logger/mlflow_logger.py (98%) rename src/{ => apeiron}/logger/wandb_logger.py (99%) create mode 100644 src/apeiron/model/__init__.py rename src/{ => apeiron}/model/torch_model_harness.py (99%) rename src/{ => apeiron}/profilers/README.md (100%) create mode 100644 src/apeiron/profilers/__init__.py rename src/{ => apeiron}/profilers/aten_flops_map.py (100%) rename src/{ => apeiron}/profilers/count_flops.py (99%) create mode 100644 src/apeiron/training/__init__.py rename src/{ => apeiron}/training/continuous_trainer.py (96%) create mode 100644 src/apeiron/training/updater/__init__.py rename src/{ => apeiron}/training/updater/base.py (94%) rename src/{ => apeiron}/training/updater/create_updater.py (73%) rename src/{ => apeiron}/training/updater/ewc.py (96%) rename src/{ => apeiron}/training/updater/jvp_reg.py (96%) rename src/{ => apeiron}/training/updater/kfac.py (97%) rename src/{ => apeiron}/training/updater/no_updater.py (89%) delete mode 100644 src/profilers/__init__.py delete mode 100644 src/training/__init__.py delete mode 100644 src/training/updater/__init__.py diff --git a/README.md b/README.md index 2b0f5d0..a2faef3 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# SIM: Self Improving Model framework +# Apeiron [![Build Status](https://github.com/AI-ModCon/BaseSim_Framework/actions/workflows/build-test.yml/badge.svg)](https://github.com/AI-ModCon/BaseSim_Framework/actions/workflows/build-test.yml) [![Coverage Status](https://codecov.io/gh/AI-ModCon/BaseSim_Framework/badge.svg?branch=main)](https://codecov.io/gh/AI-ModCon/BaseSim_Framework?branch=main) -A PyTorch framework for continuous learning that automatically detects concept drift in data streams and adapts models through JVP regularized retraining. +A PyTorch framework for continual learning that automatically detects concept drift in data streams and adapts models through JVP regularized retraining. ## What This Repository Does @@ -17,16 +17,35 @@ The pipeline runs on a changing data stream and loops through these stages: Core modules: - `src/main.py`: entry point -- `src/config/configuration.py`: TOML/env/CLI config assembly -- `src/driver/continuous_monitor.py`: monitoring + drift loop -- `src/training/continuous_trainer.py`: CL training loop -- `src/training/updater/`: CL update strategies -- `src/drift_detection/`: detectors and detector factory -- `examples/`: concrete model harness implementations +- `src/apeiron/config/configuration.py`: TOML/env/CLI config assembly +- `src/apeiron/driver/continuous_monitor.py`: monitoring + drift loop +- `src/apeiron/training/continuous_trainer.py`: CL training loop +- `src/apeiron/training/updater/`: CL update strategies +- `src/apeiron/drift_detection/`: detectors and detector factory +- `examples/`: standalone example projects (each declares `apeiron` as a dependency) ## Installation -Requires Python `>=3.13,<3.15` and Poetry. +### As a dependency in your project + +```toml +# pyproject.toml +[tool.poetry.dependencies] +apeiron = "^0.1.0" # once published to PyPI + +# Or as a path dependency during development +apeiron = { path = "../apeiron/", develop = true } +``` + +```python +from apeiron import BaseModelHarness, ContinuousMonitor, build_config +from apeiron.drift_detection import ADWINDetector +from apeiron.training.updater import BaseUpdater +``` + +### For development in this repo + +Requires Python `>=3.13,<3.14` and Poetry. ```bash poetry install @@ -39,6 +58,7 @@ From the project root: ```bash poetry run python -m src.main --config examples/mnist/mnist.toml poetry run python -m src.main --config examples/cifar/cifar10_vit.toml +poetry run python -m src.main --config examples/imagenet/imagenet_vit.toml # requires ImageNet data at data.path ``` ## Metrics Logging @@ -105,7 +125,7 @@ poetry run mypy . Platform-specific deployment guides: -- [NERSC Perlmutter](./src/deployment/perlmutter/README.md) +- [NERSC Perlmutter](./src/apeiron/deployment/perlmutter/README.md) ## What `main.py` Does - Builds the `DummyCNN_MNIST` model defined in `src/model/DummyCNN_MNIST.py`, a cross-entropy loss, and an Adam optimizer. @@ -127,4 +147,4 @@ Training logs report the task id, training/test accuracy, and replay-memory accu Platform-specific deployment guides: -- [OLCF Frontier](./src/deployment/frontier/README.md) +- [OLCF Frontier](./src/apeiron/deployment/frontier/README.md) diff --git a/examples/cifar/cifar10_vgg11.toml b/examples/cifar/cifar10_vgg11.toml index e89ebba..741878c 100644 --- a/examples/cifar/cifar10_vgg11.toml +++ b/examples/cifar/cifar10_vgg11.toml @@ -1,11 +1,11 @@ -# mnist.min.toml -version = "1" +# cifar10_vgg11.toml seed = 1337 device = "auto" multi_gpu = false +verbosity = "INFO" # DEBUG, INFO, INFO:0-9, WARNING, ERROR, CRITICAL [model] -name = "vgg11" +name = "vgg11" pretrained_path = "examples/cifar/cifar10_vgg11.pth" [data] @@ -14,11 +14,38 @@ path = "" # folder or file later for custom dataset [train] batch_size = 32 +grad_accumulation_steps = 1 num_workers = 4 init_lr = 0.001 +max_iter = 300 +[continual_learning] +update_mode = "base" +jvp_lambda = 10 +jvp_deltax_norm = 1 -[continuous_learning] -jvp_reg = 0.001 -deltax_norm = 1 -max_iter = 300 \ No newline at end of file +# ewc hyperparameters +ewc_lambda = 1000.0 +ewc_ema_decay = 0.95 + +# kfac hyperparameters +kfac_lambda = 1e-2 +kfac_ema_decay = 0.95 + +[drift_detection] +detector_name = "ADWINDetector" +detection_interval = 10 # Check drift every 10 batches +aggregation = "mean" # Average metric over 10 batches +metric_index = 0 # Monitor accuracy (0=accuracy, 1=loss) +reset_after_learning = false +max_stream_updates = 20 + +# ADWIN hyperparameters +adwin_delta = 0.002 +adwin_minor_threshold = 0.3 +adwin_moderate_threshold = 0.6 + +[visualization] +baseline = 90.0 +input = "output/cifar.csv" +output = "output/cifar10_vgg11_dashboard.png" diff --git a/examples/cifar/model.py b/examples/cifar/model.py index ded7ef1..2f317fb 100644 --- a/examples/cifar/model.py +++ b/examples/cifar/model.py @@ -5,8 +5,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, ConcatDataset -from model.torch_model_harness import BaseModelHarness -from config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness +from apeiron.config.configuration import Config from examples.cifar.src.utils import ( get_cifar_train, get_cifar_val, @@ -16,7 +16,7 @@ sample_aug, ) from examples.cifar.src.utils import load_model -from evaluation.metrics import accuracy +from apeiron.evaluation.metrics import accuracy class VisionModelCifar(nn.Module): diff --git a/examples/cifar/src/utils.py b/examples/cifar/src/utils.py index 079cb36..a5c6041 100644 --- a/examples/cifar/src/utils.py +++ b/examples/cifar/src/utils.py @@ -6,7 +6,7 @@ from torch.utils.data import Dataset, DataLoader from torchvision import datasets, transforms import torchvision.transforms.functional as TF -from config.configuration import Config +from apeiron.config.configuration import Config from examples.cifar.src import cnns, vision_transformers diff --git a/examples/imagenet/model.py b/examples/imagenet/model.py index 2898b54..a5ffb51 100644 --- a/examples/imagenet/model.py +++ b/examples/imagenet/model.py @@ -5,8 +5,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, ConcatDataset -from model.torch_model_harness import BaseModelHarness -from config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness +from apeiron.config.configuration import Config from examples.imagenet.src.utils import ( get_imagenet_train, get_imagenet_val, @@ -16,7 +16,7 @@ sample_aug, ) from examples.imagenet.src.utils import load_model -from evaluation.metrics import accuracy +from apeiron.evaluation.metrics import accuracy class VisionModelImageNet(nn.Module): diff --git a/examples/imagenet/src/utils.py b/examples/imagenet/src/utils.py index bc59e0e..e41eff0 100644 --- a/examples/imagenet/src/utils.py +++ b/examples/imagenet/src/utils.py @@ -9,7 +9,7 @@ from torch.utils.data import Dataset, DataLoader from torchvision import datasets, transforms import torchvision.transforms.functional as TF -from config.configuration import Config +from apeiron.config.configuration import Config from examples.cifar.src import cnns, vision_transformers diff --git a/examples/mnist/model.py b/examples/mnist/model.py index 094bae9..e8855d3 100644 --- a/examples/mnist/model.py +++ b/examples/mnist/model.py @@ -7,8 +7,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, ConcatDataset -from model.torch_model_harness import BaseModelHarness -from config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness +from apeiron.config.configuration import Config from examples.mnist.utils import ( get_mnist_train, get_mnist_val, @@ -17,7 +17,7 @@ make_loader, sample_aug, ) -from evaluation.metrics import accuracy +from apeiron.evaluation.metrics import accuracy class Cnn(nn.Module): diff --git a/examples/utils.py b/examples/utils.py index 1a1c413..0cde7b7 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,5 +1,5 @@ -from config.configuration import Config -from model.torch_model_harness import BaseModelHarness +from apeiron.config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness def get_example(cfg: Config) -> BaseModelHarness: diff --git a/poetry.lock b/poetry.lock index 537d74c..731164b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5579,5 +5579,5 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" -python-versions = ">=3.13,<3.15" -content-hash = "fd24b8658a0fb13bf4f9871c26c6d7728ed2308fae15990eb1d2250121a0e34e" +python-versions = ">=3.13,<3.14" +content-hash = "560a1ea156905a760105a5941451fe91e802b813171d3d119574b2d8b49c0cb5" diff --git a/pyproject.toml b/pyproject.toml index 049eb40..b12cae8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,12 @@ [project] -name = "basesim-framework" +name = "apeiron" version = "0.1.0" -description = "" +description = "A PyTorch continual learning framework for real-time concept drift detection and model adaptation." authors = [ {name = "Your Name",email = "you@example.com"} ] readme = "README.md" -requires-python = ">=3.13,<3.15" +requires-python = ">=3.13,<3.14" dependencies = [ "torch (>=2.0)", "numpy (>=2.3.4,<3.0.0)", @@ -41,7 +41,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] packages = [ - { include = "*", from = "src" }, + { include = "apeiron", from = "src" }, ] [tool.ruff] diff --git a/src/apeiron/__init__.py b/src/apeiron/__init__.py new file mode 100644 index 0000000..d5026e7 --- /dev/null +++ b/src/apeiron/__init__.py @@ -0,0 +1,56 @@ +"""Apeiron: A PyTorch continual learning framework for real-time concept drift detection and model adaptation.""" + +from apeiron.config.configuration import ( + Config, + ModelCfg, + DataCfg, + TrainCfg, + ContinualLearningCfg, + DriftDetectionCfg, + VisualizationCfg, + LoggingCfg, + build_config, +) +from apeiron.model.torch_model_harness import BaseModelHarness +from apeiron.driver.continuous_monitor import ContinuousMonitor +from apeiron.drift_detection import ( + BaseDriftDetector, + DriftSignal, + LearningRegime, + ADWINDetector, + KSWINDetector, + PageHinkleyDetector, + ModelPerformanceDetector, + EnsembleDetector, + ModelEvalDetector, +) +from apeiron.training import ContinuousTrainer +from apeiron.training.updater import BaseUpdater +from apeiron.logger import Logger, get_logger + +__all__ = [ + "Config", + "ModelCfg", + "DataCfg", + "TrainCfg", + "ContinualLearningCfg", + "DriftDetectionCfg", + "VisualizationCfg", + "LoggingCfg", + "build_config", + "BaseModelHarness", + "ContinuousMonitor", + "BaseDriftDetector", + "DriftSignal", + "LearningRegime", + "ADWINDetector", + "KSWINDetector", + "PageHinkleyDetector", + "ModelPerformanceDetector", + "EnsembleDetector", + "ModelEvalDetector", + "ContinuousTrainer", + "BaseUpdater", + "Logger", + "get_logger", +] diff --git a/src/config/__init__.py b/src/apeiron/config/__init__.py similarity index 100% rename from src/config/__init__.py rename to src/apeiron/config/__init__.py diff --git a/src/config/configuration.py b/src/apeiron/config/configuration.py similarity index 99% rename from src/config/configuration.py rename to src/apeiron/config/configuration.py index 3bc3ece..5b5cb33 100644 --- a/src/config/configuration.py +++ b/src/apeiron/config/configuration.py @@ -27,7 +27,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from logger.logger import MetricsBackend + from apeiron.logger.logger import MetricsBackend def get_available_device(multi_gpu: bool = False) -> torch.device: diff --git a/src/deployment/__init__.py b/src/apeiron/deployment/__init__.py similarity index 100% rename from src/deployment/__init__.py rename to src/apeiron/deployment/__init__.py diff --git a/src/deployment/frontier/README.md b/src/apeiron/deployment/frontier/README.md similarity index 100% rename from src/deployment/frontier/README.md rename to src/apeiron/deployment/frontier/README.md diff --git a/src/deployment/frontier/install_venv.sh b/src/apeiron/deployment/frontier/install_venv.sh similarity index 100% rename from src/deployment/frontier/install_venv.sh rename to src/apeiron/deployment/frontier/install_venv.sh diff --git a/src/deployment/frontier/mnist_example.sbatch b/src/apeiron/deployment/frontier/mnist_example.sbatch similarity index 100% rename from src/deployment/frontier/mnist_example.sbatch rename to src/apeiron/deployment/frontier/mnist_example.sbatch diff --git a/src/deployment/perlmutter/README.md b/src/apeiron/deployment/perlmutter/README.md similarity index 100% rename from src/deployment/perlmutter/README.md rename to src/apeiron/deployment/perlmutter/README.md diff --git a/src/deployment/perlmutter/install_venv.sh b/src/apeiron/deployment/perlmutter/install_venv.sh similarity index 100% rename from src/deployment/perlmutter/install_venv.sh rename to src/apeiron/deployment/perlmutter/install_venv.sh diff --git a/src/deployment/perlmutter/mnist_example.sbatch b/src/apeiron/deployment/perlmutter/mnist_example.sbatch similarity index 100% rename from src/deployment/perlmutter/mnist_example.sbatch rename to src/apeiron/deployment/perlmutter/mnist_example.sbatch diff --git a/src/drift_detection/.gitkeep b/src/apeiron/drift_detection/.gitkeep similarity index 100% rename from src/drift_detection/.gitkeep rename to src/apeiron/drift_detection/.gitkeep diff --git a/src/drift_detection/__init__.py b/src/apeiron/drift_detection/__init__.py similarity index 84% rename from src/drift_detection/__init__.py rename to src/apeiron/drift_detection/__init__.py index 9597860..3cc95b7 100644 --- a/src/drift_detection/__init__.py +++ b/src/apeiron/drift_detection/__init__.py @@ -15,17 +15,17 @@ - EnsembleDetector: Combine multiple detectors """ -from drift_detection.detectors.base import ( +from apeiron.drift_detection.detectors.base import ( BaseDriftDetector, DriftSignal, LearningRegime, ) -from drift_detection.detectors.statistical_detectors import ( +from apeiron.drift_detection.detectors.statistical_detectors import ( ADWINDetector, KSWINDetector, PageHinkleyDetector, ) -from drift_detection.detectors.model_performance_detector import ( +from apeiron.drift_detection.detectors.model_performance_detector import ( ModelPerformanceDetector, EnsembleDetector, ModelEvalDetector, diff --git a/src/drift_detection/detectors/base.py b/src/apeiron/drift_detection/detectors/base.py similarity index 100% rename from src/drift_detection/detectors/base.py rename to src/apeiron/drift_detection/detectors/base.py diff --git a/src/drift_detection/detectors/model_performance_detector.py b/src/apeiron/drift_detection/detectors/model_performance_detector.py similarity index 99% rename from src/drift_detection/detectors/model_performance_detector.py rename to src/apeiron/drift_detection/detectors/model_performance_detector.py index ff78ef3..72cf1fd 100644 --- a/src/drift_detection/detectors/model_performance_detector.py +++ b/src/apeiron/drift_detection/detectors/model_performance_detector.py @@ -12,7 +12,7 @@ from typing import Optional, List from evidently import Report from evidently.presets import DataDriftPreset -from drift_detection.detectors.base import ( +from apeiron.drift_detection.detectors.base import ( BaseDriftDetector, DriftSignal, LearningRegime, diff --git a/src/drift_detection/detectors/statistical_detectors.py b/src/apeiron/drift_detection/detectors/statistical_detectors.py similarity index 99% rename from src/drift_detection/detectors/statistical_detectors.py rename to src/apeiron/drift_detection/detectors/statistical_detectors.py index 5b60717..4e24bf9 100644 --- a/src/drift_detection/detectors/statistical_detectors.py +++ b/src/apeiron/drift_detection/detectors/statistical_detectors.py @@ -10,7 +10,7 @@ import numpy as np from river import drift as river_drift -from drift_detection.detectors.base import ( +from apeiron.drift_detection.detectors.base import ( BaseDriftDetector, DriftSignal, LearningRegime, diff --git a/src/drift_detection/load_drift_detector.py b/src/apeiron/drift_detection/load_drift_detector.py similarity index 78% rename from src/drift_detection/load_drift_detector.py rename to src/apeiron/drift_detection/load_drift_detector.py index a5d97bb..9bee28b 100644 --- a/src/drift_detection/load_drift_detector.py +++ b/src/apeiron/drift_detection/load_drift_detector.py @@ -1,5 +1,5 @@ -from config.configuration import Config -from drift_detection.detectors.base import BaseDriftDetector +from apeiron.config.configuration import Config +from apeiron.drift_detection.detectors.base import BaseDriftDetector def load_drift_detector(cfg: Config) -> BaseDriftDetector: @@ -16,7 +16,7 @@ def load_drift_detector(cfg: Config) -> BaseDriftDetector: detector_instance: BaseDriftDetector if detector_name == "ADWINDetector": - from drift_detection.detectors.statistical_detectors import ADWINDetector + from apeiron.drift_detection.detectors.statistical_detectors import ADWINDetector detector_instance = ADWINDetector( delta=cfg.drift_detection.adwin_delta, @@ -24,7 +24,7 @@ def load_drift_detector(cfg: Config) -> BaseDriftDetector: moderate_threshold=cfg.drift_detection.adwin_moderate_threshold, ) elif detector_name == "KSWINDetector": - from drift_detection.detectors.statistical_detectors import KSWINDetector + from apeiron.drift_detection.detectors.statistical_detectors import KSWINDetector detector_instance = KSWINDetector( alpha=cfg.drift_detection.kswin_alpha, @@ -32,7 +32,7 @@ def load_drift_detector(cfg: Config) -> BaseDriftDetector: stat_size=cfg.drift_detection.kswin_stat_size, ) elif detector_name == "PageHinkleyDetector": - from drift_detection.detectors.statistical_detectors import ( + from apeiron.drift_detection.detectors.statistical_detectors import ( PageHinkleyDetector, ) @@ -43,7 +43,7 @@ def load_drift_detector(cfg: Config) -> BaseDriftDetector: alpha=cfg.drift_detection.ph_alpha, ) elif detector_name == "ModelPerformanceDetector": - from drift_detection.detectors.model_performance_detector import ( + from apeiron.drift_detection.detectors.model_performance_detector import ( ModelPerformanceDetector, ) @@ -55,14 +55,14 @@ def load_drift_detector(cfg: Config) -> BaseDriftDetector: "PageHinkleyDetector, or ModelPerformanceDetector instead." ) - # from drift_detection.detectors.model_performance_detector import ( + # from apeiron.drift_detection.detectors.model_performance_detector import ( # EnsembleDetector, # ) # detector_instance = EnsembleDetector() elif detector_name == "EvalDetector": - from drift_detection.detectors.model_performance_detector import ( + from apeiron.drift_detection.detectors.model_performance_detector import ( ModelEvalDetector, ) diff --git a/src/apeiron/driver/__init__.py b/src/apeiron/driver/__init__.py new file mode 100644 index 0000000..c8a6bff --- /dev/null +++ b/src/apeiron/driver/__init__.py @@ -0,0 +1 @@ +from apeiron.driver.continuous_monitor import ContinuousMonitor as ContinuousMonitor diff --git a/src/driver/continuous_monitor.py b/src/apeiron/driver/continuous_monitor.py similarity index 97% rename from src/driver/continuous_monitor.py rename to src/apeiron/driver/continuous_monitor.py index 217287c..e095502 100644 --- a/src/driver/continuous_monitor.py +++ b/src/apeiron/driver/continuous_monitor.py @@ -14,16 +14,16 @@ import torch import numpy as np -from config.configuration import Config -from drift_detection.load_drift_detector import load_drift_detector -from drift_detection.detectors.base import DriftSignal -from profilers import FLOPSProfiler -from logger import get_logger -from training import ContinuousTrainer +from apeiron.config.configuration import Config +from apeiron.drift_detection.load_drift_detector import load_drift_detector +from apeiron.drift_detection.detectors.base import DriftSignal +from apeiron.profilers import FLOPSProfiler +from apeiron.logger import get_logger +from apeiron.training import ContinuousTrainer from tqdm import tqdm if TYPE_CHECKING: - from model.torch_model_harness import BaseModelHarness + from apeiron.model.torch_model_harness import BaseModelHarness class ContinuousMonitor: diff --git a/src/evaluation/__init__.py b/src/apeiron/evaluation/__init__.py similarity index 100% rename from src/evaluation/__init__.py rename to src/apeiron/evaluation/__init__.py diff --git a/src/evaluation/evaluation.py b/src/apeiron/evaluation/evaluation.py similarity index 100% rename from src/evaluation/evaluation.py rename to src/apeiron/evaluation/evaluation.py diff --git a/src/evaluation/metrics.py b/src/apeiron/evaluation/metrics.py similarity index 100% rename from src/evaluation/metrics.py rename to src/apeiron/evaluation/metrics.py diff --git a/src/logger/__init__.py b/src/apeiron/logger/__init__.py similarity index 61% rename from src/logger/__init__.py rename to src/apeiron/logger/__init__.py index 5231e3e..dff4878 100644 --- a/src/logger/__init__.py +++ b/src/apeiron/logger/__init__.py @@ -1,15 +1,15 @@ """Logging utilities for BaseSim Framework.""" -from logger.logger import ( +from apeiron.logger.logger import ( Logger, get_logger, reset_logger, configure_backend, MetricsBackend, ) -from logger.wandb_logger import WandBLogger, StageType, VALID_STAGES -from logger.mlflow_logger import MLFlowLogger -from logger.console_logger import ConsoleLogger +from apeiron.logger.wandb_logger import WandBLogger, StageType, VALID_STAGES +from apeiron.logger.mlflow_logger import MLFlowLogger +from apeiron.logger.console_logger import ConsoleLogger __all__ = [ "Logger", diff --git a/src/logger/console_logger.py b/src/apeiron/logger/console_logger.py similarity index 100% rename from src/logger/console_logger.py rename to src/apeiron/logger/console_logger.py diff --git a/src/logger/logger.py b/src/apeiron/logger/logger.py similarity index 95% rename from src/logger/logger.py rename to src/apeiron/logger/logger.py index 2343387..61ae49a 100644 --- a/src/logger/logger.py +++ b/src/apeiron/logger/logger.py @@ -6,10 +6,10 @@ from pathlib import Path from typing import Any, Literal -from config.configuration import Config -from logger.console_logger import ConsoleLogger -from logger.wandb_logger import StageType, WandBLogger -from logger.mlflow_logger import MLFlowLogger +from apeiron.config.configuration import Config +from apeiron.logger.console_logger import ConsoleLogger +from apeiron.logger.wandb_logger import StageType, WandBLogger +from apeiron.logger.mlflow_logger import MLFlowLogger # Type alias for metrics backend @@ -195,7 +195,7 @@ def reset_logger() -> None: def configure_backend(cfg: Config | None) -> MetricsBackend: - """Configure and return the logging backend from config.""" + """Configure and return the logging backend from apeiron.config.""" if cfg is None or cfg.logging is None: return "wandb" diff --git a/src/logger/mlflow_logger.py b/src/apeiron/logger/mlflow_logger.py similarity index 98% rename from src/logger/mlflow_logger.py rename to src/apeiron/logger/mlflow_logger.py index 0f4ada5..714aa1d 100644 --- a/src/logger/mlflow_logger.py +++ b/src/apeiron/logger/mlflow_logger.py @@ -7,11 +7,11 @@ from pathlib import Path from typing import Any, Literal, TYPE_CHECKING -from logger.wandb_logger import StageType, VALID_STAGES +from apeiron.logger.wandb_logger import StageType, VALID_STAGES if TYPE_CHECKING: import mlflow - from config.configuration import Config + from apeiron.config.configuration import Config class MLFlowLogger: diff --git a/src/logger/wandb_logger.py b/src/apeiron/logger/wandb_logger.py similarity index 99% rename from src/logger/wandb_logger.py rename to src/apeiron/logger/wandb_logger.py index c7c0c7c..20e8297 100644 --- a/src/logger/wandb_logger.py +++ b/src/apeiron/logger/wandb_logger.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Any, Literal -from config.configuration import Config +from apeiron.config.configuration import Config StageType = Literal["eval", "drift", "cl"] diff --git a/src/apeiron/model/__init__.py b/src/apeiron/model/__init__.py new file mode 100644 index 0000000..1b5859f --- /dev/null +++ b/src/apeiron/model/__init__.py @@ -0,0 +1 @@ +from apeiron.model.torch_model_harness import BaseModelHarness as BaseModelHarness diff --git a/src/model/torch_model_harness.py b/src/apeiron/model/torch_model_harness.py similarity index 99% rename from src/model/torch_model_harness.py rename to src/apeiron/model/torch_model_harness.py index fee7a53..1148ec1 100644 --- a/src/model/torch_model_harness.py +++ b/src/apeiron/model/torch_model_harness.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader from torch.optim import Optimizer -from config.configuration import Config +from apeiron.config.configuration import Config MetricFn = Callable[[Tensor, Tensor], Any] CriterionFn = Callable[[Tensor, Tensor], Tensor] diff --git a/src/profilers/README.md b/src/apeiron/profilers/README.md similarity index 100% rename from src/profilers/README.md rename to src/apeiron/profilers/README.md diff --git a/src/apeiron/profilers/__init__.py b/src/apeiron/profilers/__init__.py new file mode 100644 index 0000000..9cde9c1 --- /dev/null +++ b/src/apeiron/profilers/__init__.py @@ -0,0 +1 @@ +from apeiron.profilers.count_flops import FLOPSProfiler as FLOPSProfiler diff --git a/src/profilers/aten_flops_map.py b/src/apeiron/profilers/aten_flops_map.py similarity index 100% rename from src/profilers/aten_flops_map.py rename to src/apeiron/profilers/aten_flops_map.py diff --git a/src/profilers/count_flops.py b/src/apeiron/profilers/count_flops.py similarity index 99% rename from src/profilers/count_flops.py rename to src/apeiron/profilers/count_flops.py index a6f3892..7573237 100644 --- a/src/profilers/count_flops.py +++ b/src/apeiron/profilers/count_flops.py @@ -36,7 +36,7 @@ from torch.profiler import profile, ProfilerActivity from torch.utils.flop_counter import FlopCounterMode -from profilers.aten_flops_map import ATEN_FLOPS_PER_ELEMENT +from apeiron.profilers.aten_flops_map import ATEN_FLOPS_PER_ELEMENT # - diff --git a/src/apeiron/training/__init__.py b/src/apeiron/training/__init__.py new file mode 100644 index 0000000..1230275 --- /dev/null +++ b/src/apeiron/training/__init__.py @@ -0,0 +1 @@ +from apeiron.training.continuous_trainer import ContinuousTrainer as ContinuousTrainer diff --git a/src/training/continuous_trainer.py b/src/apeiron/training/continuous_trainer.py similarity index 96% rename from src/training/continuous_trainer.py rename to src/apeiron/training/continuous_trainer.py index 15082cc..09e4697 100644 --- a/src/training/continuous_trainer.py +++ b/src/apeiron/training/continuous_trainer.py @@ -7,11 +7,11 @@ from torch.utils.data import DataLoader from tqdm import tqdm -from config.configuration import Config -from model.torch_model_harness import BaseModelHarness -from profilers import FLOPSProfiler -from training.updater.create_updater import create_updater -from logger import get_logger +from apeiron.config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness +from apeiron.profilers import FLOPSProfiler +from apeiron.training.updater.create_updater import create_updater +from apeiron.logger import get_logger class ContinuousTrainer: diff --git a/src/apeiron/training/updater/__init__.py b/src/apeiron/training/updater/__init__.py new file mode 100644 index 0000000..42357db --- /dev/null +++ b/src/apeiron/training/updater/__init__.py @@ -0,0 +1 @@ +from apeiron.training.updater.base import BaseUpdater as BaseUpdater diff --git a/src/training/updater/base.py b/src/apeiron/training/updater/base.py similarity index 94% rename from src/training/updater/base.py rename to src/apeiron/training/updater/base.py index c3ae5d8..206476d 100644 --- a/src/training/updater/base.py +++ b/src/apeiron/training/updater/base.py @@ -5,8 +5,8 @@ import torch import torch.nn as nn -from config.configuration import Config -from model.torch_model_harness import BaseModelHarness +from apeiron.config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness class BaseUpdater: diff --git a/src/training/updater/create_updater.py b/src/apeiron/training/updater/create_updater.py similarity index 73% rename from src/training/updater/create_updater.py rename to src/apeiron/training/updater/create_updater.py index 2463c4f..37e7d76 100644 --- a/src/training/updater/create_updater.py +++ b/src/apeiron/training/updater/create_updater.py @@ -1,8 +1,8 @@ from __future__ import annotations -from config.configuration import Config -from model.torch_model_harness import BaseModelHarness -from training.updater.base import BaseUpdater +from apeiron.config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness +from apeiron.training.updater.base import BaseUpdater def create_updater(cfg: Config, modelHarness: BaseModelHarness) -> BaseUpdater: @@ -22,22 +22,22 @@ def create_updater(cfg: Config, modelHarness: BaseModelHarness) -> BaseUpdater: return BaseUpdater(cfg=cfg, modelHarness=modelHarness) if cfg.continual_learning.update_mode == "ewc_online": - from training.updater.ewc import OnlineEWCUpdater + from apeiron.training.updater.ewc import OnlineEWCUpdater return OnlineEWCUpdater(cfg=cfg, modelHarness=modelHarness) if cfg.continual_learning.update_mode == "kfac_online": - from training.updater.kfac import OnlineKFACUpdater + from apeiron.training.updater.kfac import OnlineKFACUpdater return OnlineKFACUpdater(cfg=cfg, modelHarness=modelHarness) if cfg.continual_learning.update_mode == "jvp_reg": - from training.updater.jvp_reg import JVPRegUpdater + from apeiron.training.updater.jvp_reg import JVPRegUpdater return JVPRegUpdater(cfg=cfg, modelHarness=modelHarness) if cfg.continual_learning.update_mode == "none": - from training.updater.no_updater import NoUpdater + from apeiron.training.updater.no_updater import NoUpdater return NoUpdater(cfg=cfg, modelHarness=modelHarness) diff --git a/src/training/updater/ewc.py b/src/apeiron/training/updater/ewc.py similarity index 96% rename from src/training/updater/ewc.py rename to src/apeiron/training/updater/ewc.py index e573a93..d418184 100644 --- a/src/training/updater/ewc.py +++ b/src/apeiron/training/updater/ewc.py @@ -1,8 +1,8 @@ import torch -from training.updater.base import BaseUpdater -from config.configuration import Config -from model.torch_model_harness import BaseModelHarness +from apeiron.training.updater.base import BaseUpdater +from apeiron.config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness class OnlineEWCUpdater(BaseUpdater): diff --git a/src/training/updater/jvp_reg.py b/src/apeiron/training/updater/jvp_reg.py similarity index 96% rename from src/training/updater/jvp_reg.py rename to src/apeiron/training/updater/jvp_reg.py index 5c4f19b..0f5584c 100644 --- a/src/training/updater/jvp_reg.py +++ b/src/apeiron/training/updater/jvp_reg.py @@ -11,9 +11,9 @@ import torch from torch.func import functional_call, grad, jvp -from config.configuration import Config -from model.torch_model_harness import BaseModelHarness -from training.updater.base import BaseUpdater +from apeiron.config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness +from apeiron.training.updater.base import BaseUpdater class JVPRegUpdater(BaseUpdater): diff --git a/src/training/updater/kfac.py b/src/apeiron/training/updater/kfac.py similarity index 97% rename from src/training/updater/kfac.py rename to src/apeiron/training/updater/kfac.py index af7a11e..bd466f8 100644 --- a/src/training/updater/kfac.py +++ b/src/apeiron/training/updater/kfac.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn -from training.updater.base import BaseUpdater -from config.configuration import Config -from model.torch_model_harness import BaseModelHarness +from apeiron.training.updater.base import BaseUpdater +from apeiron.config.configuration import Config +from apeiron.model.torch_model_harness import BaseModelHarness import warnings warnings.filterwarnings("ignore", "Full backward hook is firing") diff --git a/src/training/updater/no_updater.py b/src/apeiron/training/updater/no_updater.py similarity index 89% rename from src/training/updater/no_updater.py rename to src/apeiron/training/updater/no_updater.py index 76718df..21ee7cb 100644 --- a/src/training/updater/no_updater.py +++ b/src/apeiron/training/updater/no_updater.py @@ -1,7 +1,7 @@ from __future__ import annotations -from training.updater.base import BaseUpdater +from apeiron.training.updater.base import BaseUpdater import torch diff --git a/src/main.py b/src/main.py index ef4dfa6..00cf1cf 100644 --- a/src/main.py +++ b/src/main.py @@ -1,11 +1,11 @@ import sys -from logger import get_logger, configure_backend -from config.configuration import build_config, Config +from apeiron.logger import get_logger, configure_backend +from apeiron.config.configuration import build_config, Config from examples.utils import get_example -from driver.continuous_monitor import ContinuousMonitor +from apeiron.driver.continuous_monitor import ContinuousMonitor def main(argv: list[str] | None = None) -> int: diff --git a/src/profilers/__init__.py b/src/profilers/__init__.py deleted file mode 100644 index 95d47f2..0000000 --- a/src/profilers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from profilers.count_flops import FLOPSProfiler as FLOPSProfiler diff --git a/src/training/__init__.py b/src/training/__init__.py deleted file mode 100644 index c43837b..0000000 --- a/src/training/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from training.continuous_trainer import ContinuousTrainer as ContinuousTrainer diff --git a/src/training/updater/__init__.py b/src/training/updater/__init__.py deleted file mode 100644 index 61017dc..0000000 --- a/src/training/updater/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from training.updater.base import BaseUpdater as BaseUpdater diff --git a/tests/conftest.py b/tests/conftest.py index 82386c2..0b69466 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ from torch.optim import SGD from torch.utils.data import DataLoader, TensorDataset -from config.configuration import ( +from apeiron.config.configuration import ( Config, ContinualLearningCfg, DataCfg, @@ -18,7 +18,7 @@ ModelCfg, TrainCfg, ) -from model.torch_model_harness import BaseModelHarness +from apeiron.model.torch_model_harness import BaseModelHarness # --------------------------------------------------------------------------- @@ -80,7 +80,7 @@ def __init__( self._hist_train_ds = hist_train_data self._hist_val_ds = hist_val_data - from evaluation.metrics import accuracy + from apeiron.evaluation.metrics import accuracy self.eval_metrics = {"accuracy": accuracy} diff --git a/tests/test_config.py b/tests/test_config.py index f6b59cf..f0fabc2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -10,7 +10,7 @@ import pytest import torch -from config.configuration import ( +from apeiron.config.configuration import ( Config, ContinualLearningCfg, DriftDetectionCfg, @@ -287,7 +287,7 @@ def test_visualization_cfg(self): class TestDeviceSelection: def test_select_best_gpu_no_nvidia_smi(self): with patch( - "config.configuration.subprocess.check_output", + "apeiron.config.configuration.subprocess.check_output", side_effect=FileNotFoundError, ): assert _select_best_gpu() is None @@ -295,13 +295,13 @@ def test_select_best_gpu_no_nvidia_smi(self): def test_select_best_gpu_with_output(self): fake_output = b"1000\n2000\n500\n" with patch( - "config.configuration.subprocess.check_output", return_value=fake_output + "apeiron.config.configuration.subprocess.check_output", return_value=fake_output ): assert _select_best_gpu() == 1 # index of 2000 def test_get_available_device_cpu_fallback(self): with ( - patch("config.configuration._select_best_gpu", return_value=None), + patch("apeiron.config.configuration._select_best_gpu", return_value=None), patch("torch.cuda.is_available", return_value=False), patch.object(torch.backends, "mps", create=True) as mock_mps, ): diff --git a/tests/test_continuous_monitor.py b/tests/test_continuous_monitor.py index d84aaaa..c46a86c 100644 --- a/tests/test_continuous_monitor.py +++ b/tests/test_continuous_monitor.py @@ -7,8 +7,8 @@ import pytest -from drift_detection.detectors.base import DriftSignal, LearningRegime -from driver.continuous_monitor import ContinuousMonitor +from apeiron.drift_detection.detectors.base import DriftSignal, LearningRegime +from apeiron.driver.continuous_monitor import ContinuousMonitor # We patch get_logger globally for these tests since ContinuousMonitor calls it in __init__ @@ -16,8 +16,8 @@ def _patch_logger(): mock_logger = MagicMock() mock_logger.step = 0 - with patch("driver.continuous_monitor.get_logger", return_value=mock_logger): - with patch("training.continuous_trainer.get_logger", return_value=mock_logger): + with patch("apeiron.driver.continuous_monitor.get_logger", return_value=mock_logger): + with patch("apeiron.training.continuous_trainer.get_logger", return_value=mock_logger): yield mock_logger diff --git a/tests/test_continuous_trainer.py b/tests/test_continuous_trainer.py index 5e703b1..b623768 100644 --- a/tests/test_continuous_trainer.py +++ b/tests/test_continuous_trainer.py @@ -9,14 +9,14 @@ import torch from torch.utils.data import DataLoader, TensorDataset -from training.continuous_trainer import ContinuousTrainer +from apeiron.training.continuous_trainer import ContinuousTrainer @pytest.fixture(autouse=True) def _patch_logger(): mock_logger = MagicMock() mock_logger.step = 0 - with patch("training.continuous_trainer.get_logger", return_value=mock_logger): + with patch("apeiron.training.continuous_trainer.get_logger", return_value=mock_logger): yield mock_logger diff --git a/tests/test_drift_detection.py b/tests/test_drift_detection.py index d299743..1fec30c 100644 --- a/tests/test_drift_detection.py +++ b/tests/test_drift_detection.py @@ -5,22 +5,22 @@ import numpy as np import pytest -from drift_detection.detectors.base import ( +from apeiron.drift_detection.detectors.base import ( DriftSignal, LearningRegime, ) -from drift_detection.detectors.statistical_detectors import ( +from apeiron.drift_detection.detectors.statistical_detectors import ( ADWINDetector, KSWINDetector, PageHinkleyDetector, ) -from drift_detection.detectors.model_performance_detector import ( +from apeiron.drift_detection.detectors.model_performance_detector import ( EnsembleDetector, ModelEvalDetector, ModelPerformanceDetector, ) -from config.configuration import DriftDetectionCfg -from drift_detection.load_drift_detector import load_drift_detector +from apeiron.config.configuration import DriftDetectionCfg +from apeiron.drift_detection.load_drift_detector import load_drift_detector # --------------------------------------------------------------------------- diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index e77ee49..5545c69 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -5,7 +5,7 @@ import torch import pytest -from evaluation.metrics import accuracy +from apeiron.evaluation.metrics import accuracy class TestAccuracy: diff --git a/tests/test_logger.py b/tests/test_logger.py index efb3be3..429cf3d 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -6,8 +6,8 @@ import pytest -from logger.console_logger import ConsoleLogger, ColoredFormatter, StepFilter -from logger.wandb_logger import WandBLogger, VALID_STAGES +from apeiron.logger.console_logger import ConsoleLogger, ColoredFormatter, StepFilter +from apeiron.logger.wandb_logger import WandBLogger, VALID_STAGES # --------------------------------------------------------------------------- diff --git a/tests/test_model_harness.py b/tests/test_model_harness.py index 70315dc..9f36225 100644 --- a/tests/test_model_harness.py +++ b/tests/test_model_harness.py @@ -5,7 +5,7 @@ import pytest import torch -from model.torch_model_harness import BaseModelHarness +from apeiron.model.torch_model_harness import BaseModelHarness class TestUnpack: diff --git a/tests/test_profiler.py b/tests/test_profiler.py index d6e321b..dae2b40 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from profilers.count_flops import FLOPSProfiler +from apeiron.profilers.count_flops import FLOPSProfiler # TODO: throughout this, mostly checking that the outputs are of a form that we expect. We should also verify results diff --git a/tests/test_updaters.py b/tests/test_updaters.py index b950e38..e9efddb 100644 --- a/tests/test_updaters.py +++ b/tests/test_updaters.py @@ -8,13 +8,13 @@ import torch import torch.nn as nn -from config.configuration import ContinualLearningCfg -from training.updater.base import BaseUpdater -from training.updater.no_updater import NoUpdater -from training.updater.ewc import OnlineEWCUpdater -from training.updater.kfac import OnlineKFACUpdater -from training.updater.jvp_reg import JVPRegUpdater -from training.updater.create_updater import create_updater +from apeiron.config.configuration import ContinualLearningCfg +from apeiron.training.updater.base import BaseUpdater +from apeiron.training.updater.no_updater import NoUpdater +from apeiron.training.updater.ewc import OnlineEWCUpdater +from apeiron.training.updater.kfac import OnlineKFACUpdater +from apeiron.training.updater.jvp_reg import JVPRegUpdater +from apeiron.training.updater.create_updater import create_updater # --------------------------------------------------------------------------- diff --git a/tests/test_valiadation_tests.py b/tests/test_valiadation_tests.py index 557e3b2..2bec182 100644 --- a/tests/test_valiadation_tests.py +++ b/tests/test_valiadation_tests.py @@ -12,11 +12,11 @@ import pytest import torch -from config.configuration import build_config -from drift_detection.detectors.base import DriftSignal, LearningRegime -from driver.continuous_monitor import ContinuousMonitor -from logger import get_logger -import logger.logger as logger_module +from apeiron.config.configuration import build_config +from apeiron.drift_detection.detectors.base import DriftSignal, LearningRegime +from apeiron.driver.continuous_monitor import ContinuousMonitor +from apeiron.logger import get_logger +import apeiron.logger.logger as logger_module PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: