diff --git a/tests/_build_legacy_model/.gitignore b/tests/_build_legacy_model/.gitignore new file mode 100644 index 00000000..4b6ebe5f --- /dev/null +++ b/tests/_build_legacy_model/.gitignore @@ -0,0 +1 @@ +*.pt diff --git a/tests/_build_legacy_model/Dockerfile b/tests/_build_legacy_model/Dockerfile new file mode 100644 index 00000000..ddbb0e61 --- /dev/null +++ b/tests/_build_legacy_model/Dockerfile @@ -0,0 +1,39 @@ +FROM python:3.12-slim AS base +RUN pip install torch --index-url https://download.pytorch.org/whl/cpu +RUN apt-get update && \ + apt-get install -y --no-install-recommends git && \ + rm -rf /var/lib/apt/lists/* + +FROM base AS cebra-0.4.0-scikit-learn-1.4 +RUN pip install cebra==0.4.0 "scikit-learn<1.5" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM base AS cebra-0.4.0-scikit-learn-1.6 +RUN pip install cebra==0.4.0 "scikit-learn>=1.6" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM base AS cebra-rc-scikit-learn-1.4 +# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class. +# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053 +RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn<1.5" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM base AS cebra-rc-scikit-learn-1.6 +# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class. +# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053 +RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn>=1.6" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM scratch +COPY --from=cebra-0.4.0-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.4.pt +COPY --from=cebra-0.4.0-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.6.pt +COPY --from=cebra-rc-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.4.pt +COPY --from=cebra-rc-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.6.pt diff --git a/tests/_build_legacy_model/README.md b/tests/_build_legacy_model/README.md new file mode 100644 index 00000000..4bcffa2b --- /dev/null +++ b/tests/_build_legacy_model/README.md @@ -0,0 +1,13 @@ +# Helper script to build CEBRA checkpoints + +This script builds CEBRA checkpoints for different versions of scikit-learn and CEBRA. +To build all models, run: + +```bash +./generate.sh +``` + +The models are currently also stored in git directly due to their small size. + +Related issue: https://github.com/AdaptiveMotorControlLab/CEBRA/issues/207 +Related test: tests/test_sklearn_legacy.py diff --git a/tests/_build_legacy_model/create_model.py b/tests/_build_legacy_model/create_model.py new file mode 100644 index 00000000..f308d296 --- /dev/null +++ b/tests/_build_legacy_model/create_model.py @@ -0,0 +1,15 @@ +import numpy as np + +import cebra + +neural_data = np.random.normal(0, 1, (1000, 30)) # 1000 samples, 30 features +cebra_model = cebra.CEBRA(model_architecture="offset10-model", + batch_size=512, + learning_rate=1e-4, + max_iterations=10, + time_offsets=10, + num_hidden_units=16, + output_dimension=8, + verbose=True) +cebra_model.fit(neural_data) +cebra_model.save("cebra_model.pt") diff --git a/tests/_build_legacy_model/generate.sh b/tests/_build_legacy_model/generate.sh new file mode 100755 index 00000000..749a0d32 --- /dev/null +++ b/tests/_build_legacy_model/generate.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +DOCKER_BUILDKIT=1 docker build --output type=local,dest=. . diff --git a/tests/test_sklearn_legacy.py b/tests/test_sklearn_legacy.py new file mode 100644 index 00000000..4d74515f --- /dev/null +++ b/tests/test_sklearn_legacy.py @@ -0,0 +1,41 @@ +import pathlib +import urllib.request + +import numpy as np +import pytest + +from cebra.integrations.sklearn.cebra import CEBRA + +MODEL_VARIANTS = [ + "cebra-0.4.0-scikit-learn-1.4", "cebra-0.4.0-scikit-learn-1.6", + "cebra-rc-scikit-learn-1.4", "cebra-rc-scikit-learn-1.6" +] + + +@pytest.mark.parametrize("model_variant", MODEL_VARIANTS) +def test_load_legacy_model(model_variant): + """Test loading a legacy CEBRA model.""" + + X = np.random.normal(0, 1, (1000, 30)) + + model_path = pathlib.Path( + __file__ + ).parent / "_build_legacy_model" / f"cebra_model_{model_variant}.pt" + + if not model_path.exists(): + url = f"https://cebra.fra1.digitaloceanspaces.com/cebra_model_{model_variant}.pt" + urllib.request.urlretrieve(url, model_path) + + loaded_model = CEBRA.load(model_path) + + assert loaded_model.model_architecture == "offset10-model" + assert loaded_model.output_dimension == 8 + assert loaded_model.num_hidden_units == 16 + assert loaded_model.time_offsets == 10 + + output = loaded_model.transform(X) + assert isinstance(output, np.ndarray) + assert output.shape[1] == loaded_model.output_dimension + + assert hasattr(loaded_model, "state_dict_") + assert hasattr(loaded_model, "n_features_")