Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/_build_legacy_model/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.pt
39 changes: 39 additions & 0 deletions tests/_build_legacy_model/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
FROM python:3.12-slim AS base
RUN pip install torch --index-url https://download.pytorch.org/whl/cpu
RUN apt-get update && \
apt-get install -y --no-install-recommends git && \
rm -rf /var/lib/apt/lists/*

FROM base AS cebra-0.4.0-scikit-learn-1.4
RUN pip install cebra==0.4.0 "scikit-learn<1.5"
WORKDIR /app
COPY create_model.py .
RUN python create_model.py

FROM base AS cebra-0.4.0-scikit-learn-1.6
RUN pip install cebra==0.4.0 "scikit-learn>=1.6"
WORKDIR /app
COPY create_model.py .
RUN python create_model.py

FROM base AS cebra-rc-scikit-learn-1.4
# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class.
# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053
RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn<1.5"
WORKDIR /app
COPY create_model.py .
RUN python create_model.py

FROM base AS cebra-rc-scikit-learn-1.6
# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class.
# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053
RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn>=1.6"
WORKDIR /app
COPY create_model.py .
RUN python create_model.py

FROM scratch
COPY --from=cebra-0.4.0-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.4.pt
COPY --from=cebra-0.4.0-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.6.pt
COPY --from=cebra-rc-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.4.pt
COPY --from=cebra-rc-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.6.pt
13 changes: 13 additions & 0 deletions tests/_build_legacy_model/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Helper script to build CEBRA checkpoints

This script builds CEBRA checkpoints for different versions of scikit-learn and CEBRA.
To build all models, run:

```bash
./generate.sh
```

The models are currently also stored in git directly due to their small size.

Related issue: https://github.com/AdaptiveMotorControlLab/CEBRA/issues/207
Related test: tests/test_sklearn_legacy.py
15 changes: 15 additions & 0 deletions tests/_build_legacy_model/create_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np

import cebra

neural_data = np.random.normal(0, 1, (1000, 30)) # 1000 samples, 30 features
cebra_model = cebra.CEBRA(model_architecture="offset10-model",
batch_size=512,
learning_rate=1e-4,
max_iterations=10,
time_offsets=10,
num_hidden_units=16,
output_dimension=8,
verbose=True)
cebra_model.fit(neural_data)
cebra_model.save("cebra_model.pt")
3 changes: 3 additions & 0 deletions tests/_build_legacy_model/generate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

DOCKER_BUILDKIT=1 docker build --output type=local,dest=. .
41 changes: 41 additions & 0 deletions tests/test_sklearn_legacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pathlib
import urllib.request

import numpy as np
import pytest

from cebra.integrations.sklearn.cebra import CEBRA

MODEL_VARIANTS = [
"cebra-0.4.0-scikit-learn-1.4", "cebra-0.4.0-scikit-learn-1.6",
"cebra-rc-scikit-learn-1.4", "cebra-rc-scikit-learn-1.6"
]


@pytest.mark.parametrize("model_variant", MODEL_VARIANTS)
def test_load_legacy_model(model_variant):
"""Test loading a legacy CEBRA model."""

X = np.random.normal(0, 1, (1000, 30))

model_path = pathlib.Path(
__file__
).parent / "_build_legacy_model" / f"cebra_model_{model_variant}.pt"

if not model_path.exists():
url = f"https://cebra.fra1.digitaloceanspaces.com/cebra_model_{model_variant}.pt"
urllib.request.urlretrieve(url, model_path)

loaded_model = CEBRA.load(model_path)

assert loaded_model.model_architecture == "offset10-model"
assert loaded_model.output_dimension == 8
assert loaded_model.num_hidden_units == 16
assert loaded_model.time_offsets == 10

output = loaded_model.transform(X)
assert isinstance(output, np.ndarray)
assert output.shape[1] == loaded_model.output_dimension

assert hasattr(loaded_model, "state_dict_")
assert hasattr(loaded_model, "n_features_")
Loading