diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 39f882b9..6771243d 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -47,7 +47,11 @@ jobs: with: repository: AdaptiveMotorControlLab/cebra-demos path: docs/source/demo_notebooks - ref: main + # NOTE(stes): This is a temporary branch to add the xCEBRA demo notebooks + # to the docs. Once the notebooks are merged into main, we can remove this + # branch and change the ref to main. + # ref: main + ref: stes/add-xcebra - name: Set up Python 3.10 uses: actions/setup-python@v5 diff --git a/.gitignore b/.gitignore index 0563e474..e30f5f43 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,16 @@ experiments/sweeps exports/ demo_notebooks/ assets/ +.remove + +# demo run +.vscode/ +auxiliary_behavior_data.h5 +cebra_model.pt +data.npz +grid_search_models/ +neural_data.npz +saved_models/ # demo run .vscode/ diff --git a/Dockerfile b/Dockerfile index 46c8a555..1a280a30 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,7 +40,7 @@ RUN make dist FROM cebra-base # install the cebra wheel -ENV WHEEL=cebra-0.5.0-py3-none-any.whl +ENV WHEEL=cebra-0.6.0a1-py3-none-any.whl WORKDIR /build COPY --from=wheel /build/dist/${WHEEL} . RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]' diff --git a/Makefile b/Makefile index 5b8cb107..a863a921 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -CEBRA_VERSION := 0.5.0 +CEBRA_VERSION := 0.6.0a1 dist: python3 -m pip install virtualenv @@ -55,7 +55,7 @@ interrogate: --ignore-private \ --ignore-magic \ --omit-covered-files \ - -f 90 \ + -f 80 \ cebra # Build documentation using sphinx diff --git a/NOTICE.yml b/NOTICE.yml index 3588b5e6..bf498e0f 100644 --- a/NOTICE.yml +++ b/NOTICE.yml @@ -35,3 +35,83 @@ - 'tests/**/*.py' - 'docs/**/*.py' - 'conda/**/*.yml' + +- header: | + CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables + © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) + Source code: + https://github.com/AdaptiveMotorControlLab/CEBRA + + Please see LICENSE.md for the full license document: + https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md + + Adapted from https://github.com/rpatrik96/nl-causal-representations/blob/master/care_nl_ica/dep_mat.py, + licensed under the following MIT License: + + MIT License + + Copyright (c) 2022 Patrik Reizinger + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. 
+ + include: + - 'cebra/attribution/jacobian.py' + + +- header: | + CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables + © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) + Source code: + https://github.com/AdaptiveMotorControlLab/CEBRA + + Please see LICENSE.md for the full license document: + https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md + + This file contains the PyTorch implementation of Jacobian regularization described in [1]. + Judy Hoffman, Daniel A. Roberts, and Sho Yaida, + "Robust Learning with Jacobian Regularization," 2019. + [arxiv:1908.02729](https://arxiv.org/abs/1908.02729) + + Adapted from https://github.com/facebookresearch/jacobian_regularizer/blob/main/jacobian/jacobian.py + licensed under the following MIT License: + + MIT License + + Copyright (c) Facebook, Inc. and its affiliates. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + include: + - 'cebra/models/jacobian_regularizer.py' diff --git a/PKGBUILD b/PKGBUILD index 7aa985a8..48088dcb 100644 --- a/PKGBUILD +++ b/PKGBUILD @@ -1,7 +1,7 @@ # Maintainer: Steffen Schneider pkgname=python-cebra _pkgname=cebra -pkgver=0.5.0 +pkgver=0.6.0a1 pkgrel=1 pkgdesc="Consistent Embeddings of high-dimensional Recordings using Auxiliary variables" url="https://cebra.ai" diff --git a/cebra/__init__.py b/cebra/__init__.py index 0eb1f645..cb2cbd06 100644 --- a/cebra/__init__.py +++ b/cebra/__init__.py @@ -66,7 +66,7 @@ import cebra.integrations.sklearn as sklearn -__version__ = "0.5.0" +__version__ = "0.6.0a1" __all__ = ["CEBRA"] __allow_lazy_imports = False __lazy_imports = {} diff --git a/cebra/attribution/__init__.py b/cebra/attribution/__init__.py new file mode 100644 index 00000000..e1d8306a --- /dev/null +++ b/cebra/attribution/__init__.py @@ -0,0 +1,38 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Attribution methods for CEBRA. + +This module was added in v0.6.0 and contains attribution methods described and benchmarked +in [Schneider2025]_. + + +.. [Schneider2025] Schneider, S., González Laiz, R., Filippova, A., Frey, M., & Mathis, M. W. (2025). + Time-series attribution maps with regularized contrastive learning. + The 28th International Conference on Artificial Intelligence and Statistics. + https://openreview.net/forum?id=aGrCXoTB4P +""" +import cebra.registry + +cebra.registry.add_helper_functions(__name__) + +from cebra.attribution.attribution_models import * +from cebra.attribution.jacobian_attribution import * diff --git a/cebra/attribution/_jacobian.py b/cebra/attribution/_jacobian.py new file mode 100644 index 00000000..00102aeb --- /dev/null +++ b/cebra/attribution/_jacobian.py @@ -0,0 +1,142 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Adapted from https://github.com/rpatrik96/nl-causal-representations/blob/master/care_nl_ica/dep_mat.py, +# licensed under the following MIT License: +# +# MIT License +# +# Copyright (c) 2022 Patrik Reizinger +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +from typing import Union + +import numpy as np +import torch + + +def tensors_to_cpu_and_double(vars_: list[torch.Tensor]) -> list[torch.Tensor]: + """Convert a list of tensors to CPU and double precision. + + Args: + vars_: List of PyTorch tensors to convert + + Returns: + List of tensors converted to CPU and double precision + """ + cpu_vars = [] + for v in vars_: + if v.is_cuda: + v = v.to("cpu") + cpu_vars.append(v.double()) + return cpu_vars + + +def tensors_to_cuda(vars_: list[torch.Tensor], + cuda_device: str) -> list[torch.Tensor]: + """Convert a list of tensors to CUDA device. 
+ + Args: + vars_: List of PyTorch tensors to convert + cuda_device: CUDA device to move tensors to + + Returns: + List of tensors moved to specified CUDA device + """ + cpu_vars = [] + for v in vars_: + if not v.is_cuda: + v = v.to(cuda_device) + cpu_vars.append(v) + return cpu_vars + + +def compute_jacobian( + model: torch.nn.Module, + input_vars: list[torch.Tensor], + mode: str = "autograd", + cuda_device: str = "cuda", + double_precision: bool = False, + convert_to_numpy: bool = True, + hybrid_solver: bool = False, +) -> Union[torch.Tensor, np.ndarray]: + """Compute the Jacobian matrix for a given model and input. + + This function computes the Jacobian matrix using PyTorch's autograd functionality. + It supports both CPU and CUDA computation, as well as single and double precision. + + Args: + model: PyTorch model to compute Jacobian for + input_vars: List of input tensors + mode: Computation mode, currently only "autograd" is supported + cuda_device: Device to use for CUDA computation + double_precision: If True, use double precision + convert_to_numpy: If True, convert output to numpy array + hybrid_solver: If True, concatenate multiple outputs along dimension 1 + + Returns: + Jacobian matrix as either PyTorch tensor or numpy array + """ + if double_precision: + model = model.to("cpu").double() + input_vars = tensors_to_cpu_and_double(input_vars) + if hybrid_solver: + output = model(*input_vars) + output_vars = torch.cat(output, dim=1).to("cpu").double() + else: + output_vars = model(*input_vars).to("cpu").double() + else: + model = model.to(cuda_device).float() + input_vars = tensors_to_cuda(input_vars, cuda_device=cuda_device) + + if hybrid_solver: + output = model(*input_vars) + output_vars = torch.cat(output, dim=1) + else: + output_vars = model(*input_vars) + + if mode == "autograd": + jacob = [] + for i in range(output_vars.shape[1]): + grads = torch.autograd.grad( + output_vars[:, i:i + 1], + input_vars, + retain_graph=True, + create_graph=False, + grad_outputs=torch.ones(output_vars[:, i:i + 1].shape).to( + output_vars.device), + ) + jacob.append(torch.cat(grads, dim=1)) + + jacobian = torch.stack(jacob, dim=1) + + jacobian = jacobian.detach().cpu() + + if convert_to_numpy: + jacobian = jacobian.numpy() + + return jacobian diff --git a/cebra/attribution/attribution_models.py b/cebra/attribution/attribution_models.py new file mode 100644 index 00000000..ddbc7a37 --- /dev/null +++ b/cebra/attribution/attribution_models.py @@ -0,0 +1,720 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import dataclasses +import time + +import cvxpy as cp +import numpy as np +import scipy.linalg +import sklearn.metrics +import torch +import torch.nn as nn +import tqdm +from captum.attr import NeuronFeatureAblation +from captum.attr import NeuronGradient +from captum.attr import NeuronGradientShap +from captum.attr import NeuronIntegratedGradients + +import cebra +import cebra.attribution._jacobian +from cebra.attribution import register + + +@dataclasses.dataclass +class AttributionMap: + """Base class for computing attribution maps for CEBRA models. + + Args: + model: The trained CEBRA model to analyze + input_data: Input data tensor to compute attributions for + output_dimension: Output dimension to analyze. If ``None``, uses model's output dimension + num_samples: Number of samples to use for attribution. If ``None``, uses full dataset + seed: Random seed which is used to subsample the data. Only relevant if ``num_samples`` is not ``None``. + """ + + model: nn.Module + input_data: torch.Tensor + output_dimension: int = None + num_samples: int = None + seed: int = 9712341 + + def __post_init__(self): + if isinstance(self.model, cebra.models.ConvolutionalModelMixin): + data = cebra.data.TensorDataset(self.input_data, + continuous=torch.zeros( + len(self.input_data))) + data.configure_for(self.model) + offset = self.model.get_offset() + + #NOTE: explain, why do we do this again? + input_data = data[torch.arange(offset.left, + len(data) - offset.right + 1)].to( + self.input_data.device) + + # subsample the data + if self.num_samples is not None: + if self.num_samples > input_data.shape[0]: + raise ValueError( + f"You are using a bigger number of samples to " + f"subsample ({self.num_samples}) than the number " + f"of samples in the dataset ({input_data.shape[0]}).") + + random_generator = torch.Generator() + random_generator.manual_seed(self.seed) + num_elements = input_data.size(0) + random_indices = torch.randperm( + num_elements, generator=random_generator)[:self.num_samples] + input_data = input_data[random_indices] + + self.input_data = input_data + + def compute_attribution_map(self): + """Compute the attribution map for the model. + + Returns: + dict: Attribution maps and their variants + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError + + def compute_metrics(self, attribution_map, ground_truth_map): + """Compute metrics comparing attribution map to ground truth. + + This function computes various statistical metrics to compare the attribution values + between connected and non-connected neurons based on a ground truth connectivity map. + It separates the attribution values into two groups based on the binary ground truth, + and calculates summary statistics and differences between these groups. 
+ + Args: + attribution_map: Computed attribution values representing the strength of connections + between neurons + ground_truth_map: Binary ground truth connectivity map where True indicates a + connected neuron and False indicates a non-connected neuron + + Returns: + dict: Dictionary containing the following metrics: + - max/mean/min_nonconnected: Statistics for non-connected neurons + - max/mean/min_connected: Statistics for connected neurons + - gap_max: Difference between max connected and max non-connected values + - gap_mean: Difference between mean connected and mean non-connected values + - gap_min: Difference between min connected and min non-connected values + - gap_minmax: Difference between min connected and max non-connected values + - max/min_jacobian: Global max/min values across all neurons + """ + assert np.issubdtype(ground_truth_map.dtype, bool) + connected_neurons = attribution_map[np.where(ground_truth_map)] + non_connected_neurons = attribution_map[np.where(~ground_truth_map)] + assert connected_neurons.size == ground_truth_map.sum() + assert non_connected_neurons.size == ground_truth_map.size - ground_truth_map.sum( + ) + assert connected_neurons.size + non_connected_neurons.size == attribution_map.size == ground_truth_map.size + + max_connected = np.max(connected_neurons) + mean_connected = np.mean(connected_neurons) + min_connected = np.min(connected_neurons) + + max_nonconnected = np.max(non_connected_neurons) + mean_nonconnected = np.mean(non_connected_neurons) + min_nonconnected = np.min(non_connected_neurons) + + metrics = { + 'max_nonconnected': max_nonconnected, + 'mean_nonconnected': mean_nonconnected, + 'min_nonconnected': min_nonconnected, + 'max_connected': max_connected, + 'mean_connected': mean_connected, + 'min_connected': min_connected, + 'gap_max': max_connected - max_nonconnected, + 'gap_mean': mean_connected - mean_nonconnected, + 'gap_min': min_connected - min_nonconnected, + 'gap_minmax': min_connected - max_nonconnected, + 'max_jacobian': np.max(attribution_map), + 'min_jacobian': np.min(attribution_map), + } + return metrics + + def compute_attribution_score(self, attribution_map, ground_truth_map): + """Compute ROC AUC score between attribution map and ground truth. + + Args: + attribution_map: Computed attribution values + ground_truth_map: Binary ground truth connectivity map + + Returns: + float: ROC AUC score + """ + assert attribution_map.shape == ground_truth_map.shape + assert np.issubdtype(ground_truth_map.dtype, bool) + fpr, tpr, _ = sklearn.metrics.roc_curve( # noqa: codespell:ignore fpr, tpr + ground_truth_map.flatten(), attribution_map.flatten()) + auc = sklearn.metrics.auc(fpr, tpr) # noqa: codespell:ignore fpr, tpr + return auc + + @staticmethod + def _check_moores_penrose_conditions( + matrix: np.ndarray, matrix_inverse: np.ndarray) -> np.ndarray: + """Check Moore-Penrose conditions for a single matrix pair. 
+ + Args: + matrix: Input matrix + matrix_inverse: Putative pseudoinverse matrix + + Returns: + np.ndarray: Boolean array indicating which conditions are satisfied + """ + matrix_inverse = matrix_inverse.T + condition_1 = np.allclose(matrix @ matrix_inverse @ matrix, matrix) + condition_2 = np.allclose(matrix_inverse @ matrix @ matrix_inverse, + matrix_inverse) + condition_3 = np.allclose((matrix @ matrix_inverse).T, + matrix @ matrix_inverse) + condition_4 = np.allclose((matrix_inverse @ matrix).T, + matrix_inverse @ matrix) + + return np.array([condition_1, condition_2, condition_3, condition_4]) + + def check_moores_penrose_conditions( + self, jacobian: np.ndarray, + jacobian_pseudoinverse: np.ndarray) -> np.ndarray: + """Check Moore-Penrose conditions for Jacobian matrices. + + Args: + jacobian: Jacobian matrices of shape (num samples, output_dim, num_neurons) + jacobian_pseudoinverse: Pseudoinverse matrices of shape (num samples, num_neurons, output_dim) + + Returns: + Boolean array of shape (num samples, 4) indicating satisfied conditions + """ + # check the four conditions + conditions = np.zeros((jacobian.shape[0], 4)) + for i, (matrix, inverse_matrix) in enumerate( + zip(jacobian, jacobian_pseudoinverse)): + conditions[i] = self._check_moores_penrose_conditions( + matrix, inverse_matrix) + return conditions + + def _inverse(self, jacobian, method="lsq"): + """Compute inverse/pseudoinverse of Jacobian matrices. + + Args: + jacobian: Input Jacobian matrices + method: Inversion method ('lsq_cvxpy', 'lsq', or 'svd') + + Returns: + (Inverse matrices, computation time) + """ + # NOTE(stes): Before we used "np.linalg.pinv" here, which + # is numerically not stable for the Jacobian matrices we + # need to compute. + start_time = time.time() + Jfinv = np.zeros_like(jacobian) + if method == "lsq_cvxpy": + for i in tqdm(range(len(jacobian))): + Jfinv[i] = self._inverse_lsq_cvxpy(jacobian[i]).T + elif method == "lsq": + for i in range(len(jacobian)): + Jfinv[i] = self._inverse_lsq_scipy(jacobian[i]).T + elif method == "svd": + for i in range(len(jacobian)): + Jfinv[i] = self._inverse_svd(jacobian[i]).T + else: + raise NotImplementedError(f"Method {method} not implemented.") + end_time = time.time() + return Jfinv, end_time - start_time + + @staticmethod + def _inverse_lsq_cvxpy(matrix: np.ndarray, + solver: str = 'SCS') -> np.ndarray: + """Compute least squares inverse using CVXPY. + + Args: + matrix: Input matrix + solver: CVXPY solver to use + + Returns: + np.ndarray: Least squares inverse matrix + """ + + matrix_param = cp.Parameter((matrix.shape[0], matrix.shape[1])) + matrix_param.value = matrix + + identity = np.eye(matrix.shape[0]) + matrix_inverse = cp.Variable((matrix.shape[1], matrix.shape[0])) + # noqa: codespell + objective = cp.Minimize( + cp.norm(matrix @ matrix_inverse - identity, + "fro")) # noqa: codespell:ignore fro + prob = cp.Problem(objective) + prob.solve(verbose=False, solver=solver) + + return matrix_inverse.value + + @staticmethod + def _inverse_lsq_scipy(jacobian): + """Compute least squares inverse using scipy.linalg.lstsq. + + Args: + jacobian: Input Jacobian matrix + + Returns: + np.ndarray: Least squares inverse matrix + """ + return scipy.linalg.lstsq(jacobian, np.eye(jacobian.shape[0]))[0] + + @staticmethod + def _inverse_svd(jacobian): + """Compute pseudoinverse using SVD. 
+ + Args: + jacobian: Input Jacobian matrix + + Returns: + np.ndarray: Pseudoinverse matrix + """ + return scipy.linalg.pinv(jacobian) + + def _reduce_attribution_map(self, attribution_maps): + """Reduce attribution maps by averaging across dimensions. + + Args: + attribution_maps: Dictionary of attribution maps to reduce + + Returns: + dict: Reduced attribution maps + """ + + def _reduce(full_jacobian): + if full_jacobian.ndim == 4: + jf_convabs = abs(full_jacobian).mean(-1) + jf = full_jacobian.mean(-1) + else: + jf_convabs = full_jacobian + jf = full_jacobian + return jf, jf_convabs + + result = {} + for key, value in attribution_maps.items(): + result[key], result[f'{key}-convabs'] = _reduce(value) + return result + + +@dataclasses.dataclass +@register("jacobian-based") +class JFMethodBased(AttributionMap): + """Compute the attribution map using the Jacobian of the model encoder.""" + + def _compute_jacobian(self, input_data): + return cebra.attribution._jacobian.compute_jacobian( + self.model, + input_vars=[input_data], + mode="autograd", + cuda_device=self.input_data.device, + double_precision=False, + convert_to_numpy=True, + hybrid_solver=False, + ) + + def compute_attribution_map(self): + + full_jacobian = self._compute_jacobian(self.input_data) + + result = {} + for key, value in self._reduce_attribution_map({ + 'jf': full_jacobian + }).items(): + result[key] = value + for method in ['lsq', 'svd']: + print(f"Computing inverse for {key} with method {method}") + result[f"{key}-inv-{method}"], result[ + f'time_inversion_{method}'] = self._inverse(value, + method=method) + # result[f"{key}-inv-{method}-conditions"] = self.check_moores_penrose_conditions(value, result[f"{key}-inv-{method}"]) + + return result + + +@dataclasses.dataclass +@register("jacobian-based-batched") +class JFMethodBasedBatched(JFMethodBased): + """Compute an attribution map based on the Jacobian using mini-batches. + + See also: + :py:class:`JFMethodBased` + """ + + def compute_attribution_map(self, batch_size=1024): + if batch_size > self.input_data.shape[0]: + raise ValueError( + f"Batch size ({batch_size}) is bigger than data ({self.input_data.shape[0]})" + ) + + input_data_batches = torch.split(self.input_data, batch_size) + full_jacobian = [] + for input_data_batch in input_data_batches: + jacobian_batch = self._compute_jacobian(input_data_batch) + full_jacobian.append(jacobian_batch) + full_jacobian = np.vstack(full_jacobian) + + result = {} + for key, value in self._reduce_attribution_map({ + 'jf': full_jacobian + }).items(): + result[key] = value + for method in ['lsq', 'svd']: + + result[f"{key}-inv-{method}"], result[ + f'time_inversion_{method}'] = self._inverse(value, + method=method) + + return result + + +@dataclasses.dataclass +@register("neuron-gradient") +class NeuronGradientMethod(AttributionMap): + """Compute the attribution map using the neuron gradient from Captum. + + Note: + This method is equivalent to Jacobian-based attributions, but + uses a different backend implementation. 
+ """ + + def __post_init__(self): + super().__post_init__() + self.captum_model = NeuronGradient(forward_func=self.model, + layer=self.model) + + def compute_attribution_map(self, attribute_to_neuron_input=False): + attribution_map = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=self.input_data, + attribute_to_neuron_input=attribute_to_neuron_input, + neuron_selector=s) + + attribution_map.append(att.detach().cpu().numpy()) + + attribution_map = np.array(attribution_map) + attribution_map = np.swapaxes(attribution_map, 1, 0) + + result = {} + for key, value in self._reduce_attribution_map({ + 'neuron-gradient': attribution_map + }).items(): + result[key] = value + + for method in ['lsq', 'svd']: + result[f"{key}-inv-{method}"], result[ + f'time_inversion_{method}'] = self._inverse(value, + method=method) + # result[f"{key}-inv-{method}-conditions"] = self.check_moores_penrose_conditions(value, result[f"{key}-inv-{method}"]) + + return result + + +@dataclasses.dataclass +@register("neuron-gradient-batched") +class NeuronGradientMethodBatched(NeuronGradientMethod): + """As :py:class:`NeuronGradientMethod`, but using mini-batches. + + See also: + :py:class:`NeuronGradientMethod` + """ + + def compute_attribution_map(self, + attribute_to_neuron_input=False, + batch_size=1024): + input_data_batches = torch.split(self.input_data, batch_size) + + attribution_map = [] + for input_data_batch in input_data_batches: + attribution_map_batch = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=input_data_batch, + attribute_to_neuron_input=attribute_to_neuron_input, + neuron_selector=s) + + attribution_map_batch.append(att.detach().cpu().numpy()) + + attribution_map_batch = np.array(attribution_map_batch) + attribution_map_batch = np.swapaxes(attribution_map_batch, 1, 0) + attribution_map.append(attribution_map_batch) + + attribution_map = np.vstack(attribution_map) + return self._reduce_attribution_map({ + 'neuron-gradient': attribution_map, + #'neuron-gradient-invsvd': self._inverse_svd(attribution_map) + }) + + +@dataclasses.dataclass +@register("feature-ablation") +class FeatureAblationMethod(AttributionMap): + """Compute the attribution map using the feature ablation method from Captum.""" + + def __post_init__(self): + super().__post_init__() + self.captum_model = NeuronFeatureAblation(forward_func=self.model, + layer=self.model) + + def compute_attribution_map(self, + baselines=None, + feature_mask=None, + perturbations_per_eval=1, + attribute_to_neuron_input=False): + attribution_map = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=self.input_data, + neuron_selector=s, + baselines=baselines, + perturbations_per_eval=perturbations_per_eval, + feature_mask=feature_mask, + attribute_to_neuron_input=attribute_to_neuron_input) + + attribution_map.append(att.detach().cpu().numpy()) + + attribution_map = np.array(attribution_map) + attribution_map = np.swapaxes(attribution_map, 1, 0) + return self._reduce_attribution_map( + {'feature-ablation': attribution_map}) + + +@dataclasses.dataclass +@register("feature-ablation-batched") +class FeatureAblationMethodBAtched(FeatureAblationMethod): + """As :py:class:`FeatureAblationMethod`, but using mini-batches. 
+ + See also: + :py:class:`FeatureAblationMethod` + """ + + def compute_attribution_map(self, + baselines=None, + feature_mask=None, + perturbations_per_eval=1, + attribute_to_neuron_input=False, + batch_size=1024): + + input_data_batches = torch.split(self.input_data, batch_size) + attribution_map = [] + for input_data_batch in input_data_batches: + attribution_map_batch = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=input_data_batch, + neuron_selector=s, + baselines=baselines, + perturbations_per_eval=perturbations_per_eval, + feature_mask=feature_mask, + attribute_to_neuron_input=attribute_to_neuron_input) + + attribution_map_batch.append(att.detach().cpu().numpy()) + + attribution_map_batch = np.array(attribution_map_batch) + attribution_map_batch = np.swapaxes(attribution_map_batch, 1, 0) + attribution_map.append(attribution_map_batch) + + attribution_map = np.vstack(attribution_map) + return self._reduce_attribution_map( + {'feature-ablation': attribution_map}) + + +@dataclasses.dataclass +@register("integrated-gradients") +class IntegratedGradientsMethod(AttributionMap): + """Compute the attribution map using the integrated gradients method from Captum.""" + + def __post_init__(self): + super().__post_init__() + self.captum_model = NeuronIntegratedGradients(forward_func=self.model, + layer=self.model) + + def compute_attribution_map(self, + n_steps=50, + method='gausslegendre', + internal_batch_size=None, + attribute_to_neuron_input=False, + baselines=None): + if internal_batch_size == "dataset": + internal_batch_size = len(self.input_data) + + attribution_map = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=self.input_data, + neuron_selector=s, + n_steps=n_steps, + method=method, + internal_batch_size=internal_batch_size, + attribute_to_neuron_input=attribute_to_neuron_input, + baselines=baselines, + ) + attribution_map.append(att.detach().cpu().numpy()) + + attribution_map = np.array(attribution_map) + attribution_map = np.swapaxes(attribution_map, 1, 0) + return self._reduce_attribution_map( + {'integrated-gradients': attribution_map}) + + +@dataclasses.dataclass +@register("integrated-gradients-batched") +class IntegratedGradientsMethodBatched(IntegratedGradientsMethod): + """As :py:class:`IntegratedGradientsMethod`, but using mini-batches. 
+ + See also: + :py:class:`IntegratedGradientsMethod` + """ + + def compute_attribution_map(self, + n_steps=50, + method='gausslegendre', + internal_batch_size=None, + attribute_to_neuron_input=False, + baselines=None, + batch_size=1024): + + input_data_batches = torch.split(self.input_data, batch_size) + attribution_map = [] + for input_data_batch in input_data_batches: + attribution_map_batch = [] + if internal_batch_size == "dataset": + internal_batch_size = len(input_data_batch) + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=input_data_batch, + neuron_selector=s, + n_steps=n_steps, + method=method, + internal_batch_size=internal_batch_size, + attribute_to_neuron_input=attribute_to_neuron_input, + baselines=baselines, + ) + attribution_map_batch.append(att.detach().cpu().numpy()) + + attribution_map_batch = np.array(attribution_map_batch) + attribution_map_batch = np.swapaxes(attribution_map_batch, 1, 0) + attribution_map.append(attribution_map_batch) + + attribution_map = np.vstack(attribution_map) + return self._reduce_attribution_map( + {'integrated-gradients': attribution_map}) + + +@dataclasses.dataclass +@register("neuron-gradient-shap") +class NeuronGradientShapMethod(AttributionMap): + """Compute the attribution map using the neuron gradient SHAP method from Captum.""" + + def __post_init__(self): + super().__post_init__() + self.captum_model = NeuronGradientShap(forward_func=self.model, + layer=self.model) + + def compute_attribution_map(self, + baselines: str, + n_samples=5, + stdevs=0.0, + attribute_to_neuron_input=False): + + if baselines == "zeros": + baselines = torch.zeros(size=(self.input_data.shape), + device=self.input_data.device) + elif baselines == "shuffle": + data = self.input_data.flatten() + data = data[torch.randperm(len(data))] + baselines = data.reshape(self.input_data.shape) + else: + raise NotImplementedError(f"Baseline {baselines} not implemented.") + + attribution_map = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=self.input_data, + neuron_selector=s, + baselines=baselines, + n_samples=n_samples, + stdevs=stdevs, + attribute_to_neuron_input=attribute_to_neuron_input, + ) + + attribution_map.append(att.detach().cpu().numpy()) + + attribution_map = np.array(attribution_map) + attribution_map = np.swapaxes(attribution_map, 1, 0) + return self._reduce_attribution_map( + {'neuron-gradient-shap': attribution_map}) + + +@dataclasses.dataclass +@register("neuron-gradient-shap-batched") +class NeuronGradientShapMethodBatched(NeuronGradientShapMethod): + """As :py:class:`NeuronGradientShapMethod`, but using mini-batches. 
+ + See also: + :py:class:`NeuronGradientShapMethod` + """ + + def compute_attribution_map(self, + baselines: str, + n_samples=5, + stdevs=0.0, + attribute_to_neuron_input=False, + batch_size=1024): + + if baselines == "zeros": + baselines = torch.zeros(size=(self.input_data.shape), + device=self.input_data.device) + elif baselines == "shuffle": + data = self.input_data.flatten() + data = data[torch.randperm(len(data))] + baselines = data.reshape(self.input_data.shape) + else: + raise NotImplementedError(f"Baseline {baselines} not implemented.") + + input_data_batches = torch.split(self.input_data, batch_size) + attribution_map = [] + for input_data_batch in input_data_batches: + attribution_map_batch = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=input_data_batch, + neuron_selector=s, + baselines=baselines, + n_samples=n_samples, + stdevs=stdevs, + attribute_to_neuron_input=attribute_to_neuron_input, + ) + + attribution_map_batch.append(att.detach().cpu().numpy()) + + attribution_map_batch = np.array(attribution_map_batch) + attribution_map_batch = np.swapaxes(attribution_map_batch, 1, 0) + attribution_map.append(attribution_map_batch) + + attribution_map = np.vstack(attribution_map) + return self._reduce_attribution_map( + {'neuron-gradient-shap': attribution_map}) diff --git a/cebra/attribution/jacobian_attribution.py b/cebra/attribution/jacobian_attribution.py new file mode 100644 index 00000000..f8db8344 --- /dev/null +++ b/cebra/attribution/jacobian_attribution.py @@ -0,0 +1,95 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tools for computing attribution maps.""" + +from typing import Literal + +import numpy as np +import torch +from torch import nn + +import cebra.attribution._jacobian + +__all__ = ["get_attribution_map"] + + +def _prepare_inputs(inputs): + if not isinstance(inputs, torch.Tensor): + inputs = torch.from_numpy(inputs) + inputs.requires_grad_(True) + return inputs + + +def _prepare_model(model): + for p in model.parameters(): + p.requires_grad_(False) + return model + + +def get_attribution_map( + model: nn.Module, + input_data: torch.Tensor, + double_precision: bool = True, + convert_to_numpy: bool = True, + aggregate: Literal["mean", "sum", "max"] = "mean", + transform: Literal["none", "abs"] = "none", + hybrid_solver: bool = False, +): + """Estimate attribution maps using the Jacobian pseudo-inverse. + + The function estimates Jacobian matrices for each point in the model, + computes the pseudo-inverse (for every sample) and then aggregates + the resulting matrices to compute an attribution map. + + Args: + model: The neural network model for which to compute attributions. 
+ input_data: Input tensor or numpy array to compute attributions for. + double_precision: If ``True``, use double precision for computation. + convert_to_numpy: If ``True``, convert the output to numpy arrays. + aggregate: Method to aggregate attribution values across samples. + Options are ``"mean"``, ``"sum"``, or ``"max"``. + transform: Transformation to apply to attribution values. + Options are ``"none"`` or ``"abs"``. + hybrid_solver: If ``True``, handle multi-objective models differently. + + Returns: + A tuple containing the Jacobian matrix of shape (num_samples, output_dim, input_dim) + and the pseudo-inverse of the Jacobian matrix. + + """ + assert aggregate in ["mean", "sum", "max"] + + input_data = _prepare_inputs(input_data) + model = _prepare_model(model) + + # compute jacobian CEBRA model + jf = cebra.attribution._jacobian.compute_jacobian( + model, + input_vars=[input_data], + mode="autograd", + double_precision=double_precision, + convert_to_numpy=convert_to_numpy, + hybrid_solver=hybrid_solver, + ) + + jhatg = np.linalg.pinv(jf) + return jf, jhatg diff --git a/cebra/data/__init__.py b/cebra/data/__init__.py index ec753f18..697801ed 100644 --- a/cebra/data/__init__.py +++ b/cebra/data/__init__.py @@ -46,10 +46,8 @@ # these imports will not be reordered by isort (see .isort.cfg) from cebra.data.base import * from cebra.data.datatypes import * - from cebra.data.single_session import * from cebra.data.multi_session import * - +from cebra.data.multiobjective import * from cebra.data.datasets import * - from cebra.data.helper import * diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index dbb2f1f5..24735f47 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -22,7 +22,7 @@ """Pre-defined datasets.""" import types -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, TYPE_CHECKING, Union import numpy as np import numpy.typing as npt @@ -30,8 +30,14 @@ import cebra.data as cebra_data import cebra.helper as cebra_helper +import cebra.io as cebra_io +from cebra.data.datatypes import Batch +from cebra.data.datatypes import BatchIndex from cebra.data.datatypes import Offset +if TYPE_CHECKING: + from cebra.models import Model + class TensorDataset(cebra_data.SingleSessionDataset): """Discrete and/or continuously indexed dataset based on torch/numpy arrays. @@ -295,3 +301,137 @@ def _apply(self, func): def _iter_property(self, attr): return (getattr(data, attr) for data in self.iter_sessions()) + + +# TODO(stes): This should be a single session dataset? +class DatasetxCEBRA(cebra_io.HasDevice): + """Dataset class for xCEBRA models. + + This class handles neural data and associated labels for xCEBRA models, providing + functionality for data loading and batch preparation. + + Attributes: + neural: Neural data as a torch.Tensor or numpy array + labels: Labels associated with the data + offset: Offset for the dataset + + Args: + neural: Neural data as a torch.Tensor or numpy array + device: Device to store the data on (default: "cpu") + **labels: Additional keyword arguments for labels associated with the data + """ + + def __init__( + self, + neural: Union[torch.Tensor, npt.NDArray], + device="cpu", + **labels, + ): + super().__init__(device) + self.neural = neural + self.labels = labels + self.offset = Offset(0, 1) + + @property + def input_dimension(self) -> int: + """Get the input dimension of the neural data. 
+ + Returns: + The number of features in the neural data + """ + return self.neural.shape[1] + + def __len__(self): + """Get the length of the dataset. + + Returns: + Number of samples in the dataset + """ + return len(self.neural) + + def configure_for(self, model: "Model"): + """Configure the dataset offset for the provided model. + + Call this function before indexing the dataset. This sets the + :py:attr:`offset` attribute of the dataset. + + Args: + model: The model to configure the dataset for. + """ + self.offset = model.get_offset() + + def expand_index(self, index: torch.Tensor) -> torch.Tensor: + """Expand indices based on the configured offset. + + Args: + index: A one-dimensional tensor of type long containing indices + to select from the dataset. + + Returns: + An expanded index of shape ``(len(index), len(self.offset))`` where + the elements will be + ``expanded_index[i,j] = index[i] + j - self.offset.left`` for all ``j`` + in ``range(0, len(self.offset))``. + + Note: + Requires the :py:attr:`offset` to be set. + """ + offset = torch.arange(-self.offset.left, + self.offset.right, + device=index.device) + + index = torch.clamp(index, self.offset.left, + len(self) - self.offset.right) + + return index[:, None] + offset[None, :] + + def __getitem__(self, index): + """Get item(s) from the dataset at the specified index. + + Args: + index: Index or indices to retrieve + + Returns: + The neural data at the specified indices, with dimensions transposed + """ + index = self.expand_index(index) + return self.neural[index].transpose(2, 1) + + def load_batch_supervised(self, index: Batch, + labels_supervised) -> torch.Tensor: + """Load a batch for supervised learning. + + Args: + index: Batch indices for reference data + labels_supervised: Labels to load for supervised learning + + Returns: + Batch containing reference data and corresponding labels + """ + assert index.negative is None + assert index.positive is None + labels = [ + self.labels[label].to(self.device) for label in labels_supervised + ] + + return Batch( + reference=self[index.reference], + positive=[label[index.reference] for label in labels], + negative=None, + ) + + def load_batch_contrastive(self, index: BatchIndex) -> Batch: + """Load a batch for contrastive learning. + + Args: + index: BatchIndex containing reference, positive and negative indices + + Returns: + Batch containing reference, positive and negative samples + """ + assert isinstance(index.positive, list) + return Batch( + reference=self[index.reference], + positive=[self[idx] for idx in index.positive], + negative=self[index.negative], + ) diff --git a/cebra/data/multiobjective.py b/cebra/data/multiobjective.py new file mode 100644 index 00000000..f700d1c4 --- /dev/null +++ b/cebra/data/multiobjective.py @@ -0,0 +1,173 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import literate_dataclasses as dataclasses + +import cebra.data as cebra_data +import cebra.distributions +from cebra.data.datatypes import BatchIndex +from cebra.distributions.continuous import Prior + + +@dataclasses.dataclass +class MultiObjectiveLoader(cebra_data.Loader): + """Baseclass of Multiobjective Data Loader. Yields batches of the specified size from the given dataset object. + """ + dataset: int = dataclasses.field( + default=None, + doc="""A dataset instance specifying a ``__getitem__`` function.""", + ) + num_steps: int = dataclasses.field(default=None) + batch_size: int = dataclasses.field(default=None) + + def __post_init__(self): + super().__post_init__() + if self.batch_size > len(self.dataset.neural): + raise ValueError("Batch size can't be larger than data.") + self.prior = Prior(self.dataset.neural, device=self.device) + + def get_indices(self): + return NotImplementedError + + def __iter__(self): + return NotImplementedError + + def add_config(self, config): + raise NotImplementedError + + +@dataclasses.dataclass +class SupervisedMultiObjectiveLoader(MultiObjectiveLoader): + """Supervised Multiobjective data Loader. Yields batches of the specified size from the given dataset object. + """ + sampling_mode_supervised: str = dataclasses.field( + default="ref_shared", + doc="""Type of sampling performed, re whether reference are shared or not. + are shared. Options will be ref_shared, independent.""") + + def __post_init__(self): + super().__post_init__() + self.labels = [] + + def add_config(self, config): + self.labels.append(config['label']) + + def get_indices(self, num_samples: int): + if self.sampling_mode_supervised == "ref_shared": + reference_idx = self.prior.sample_prior(num_samples) + else: + raise ValueError( + f"Sampling mode {self.sampling_mode_supervised} is not implemented." + ) + + batch_index = BatchIndex( + reference=reference_idx, + positive=None, + negative=None, + ) + + return batch_index + + def __iter__(self): + for _ in range(len(self)): + index = self.get_indices(num_samples=self.batch_size) + yield self.dataset.load_batch_supervised(index, self.labels) + + +@dataclasses.dataclass +class ContrastiveMultiObjectiveLoader(MultiObjectiveLoader): + """Contrastive Multiobjective data Loader. Yields batches of the specified size from the given dataset object. + """ + + sampling_mode_contrastive: str = dataclasses.field( + default="refneg_shared", + doc= + """Type of sampling performed, re whether reference and negative samples + are shared. 
Options will be ref_shared, neg_shared and refneg_shared""" + ) + + def __post_init__(self): + super().__post_init__() + self.distributions = [] + + def add_config(self, config): + kwargs_distribution = config['kwargs'] + if config['distribution'] == "time": + distribution = cebra.distributions.TimeContrastive( + time_offset=kwargs_distribution['time_offset'], + num_samples=len(self.dataset.neural), + device=self.device, + ) + elif config['distribution'] == "time_delta": + distribution = cebra.distributions.TimedeltaDistribution( + continuous=self.dataset.labels[ + kwargs_distribution['label_name']], + time_delta=kwargs_distribution['time_delta'], + device=self.device) + elif config['distribution'] == "delta_normal": + distribution = cebra.distributions.DeltaNormalDistribution( + continuous=self.dataset.labels[ + kwargs_distribution['label_name']], + delta=kwargs_distribution['delta'], + device=self.device) + elif config['distribution'] == "delta_vmf": + distribution = cebra.distributions.DeltaVMFDistribution( + continuous=self.dataset.labels[ + kwargs_distribution['label_name']], + delta=kwargs_distribution['delta'], + device=self.device) + else: + raise NotImplementedError( + f"Distribution {config['distribution']} is not implemented yet." + ) + + self.distributions.append(distribution) + + def get_indices(self, num_samples: int): + """Sample and return the specified number of indices.""" + + if self.sampling_mode_contrastive == "refneg_shared": + ref_and_neg = self.prior.sample_prior(num_samples * 2) + reference_idx = ref_and_neg[:num_samples] + negative_idx = ref_and_neg[num_samples:] + + positives_idx = [] + for distribution in self.distributions: + idx = distribution.sample_conditional(reference_idx) + positives_idx.append(idx) + + batch_index = BatchIndex( + reference=reference_idx, + positive=positives_idx, + negative=negative_idx, + ) + else: + raise ValueError( + f"Sampling mode {self.sampling_mode_contrastive} is not implemented yet." + ) + + return batch_index + + def __iter__(self): + for _ in range(len(self)): + index = self.get_indices(num_samples=self.batch_size) + yield self.dataset.load_batch_contrastive(index) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 7802b787..31d9b9d7 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -227,6 +227,13 @@ def _init_distribution(self): self.dataset.continuous_index, self.delta, device=self.device) + # TODO(stes): Add this distribution from internal xCEBRA codebase at a later point + # in time, currently not in use. + #elif self.conditional == "delta_vmf": + # self.distribution = cebra.distributions.DeltaVMFDistribution( + # self.dataset.continuous_index, + # self.delta, + # device=self.device) else: raise ValueError(self.conditional) @@ -334,6 +341,7 @@ class HybridDataLoader(cebra_data.Loader): """ conditional: str = dataclasses.field(default="time_delta") + time_distribution: str = dataclasses.field(default="time") time_offset: int = dataclasses.field(default=10) delta: float = dataclasses.field(default=0.1) @@ -351,17 +359,59 @@ def __post_init__(self): # e.g. integrating the FAISS dataloader back in. 
super().__post_init__() - if self.conditional != "time_delta": - raise NotImplementedError( - "Hybrid training is currently only implemented using the ``time_delta`` " - "continual distribution.") - - self.time_distribution = cebra.distributions.TimeContrastive( - time_offset=self.time_offset, - num_samples=len(self.dataset.neural), - device=self.device) - self.behavior_distribution = cebra.distributions.TimedeltaDistribution( - self.dataset.continuous_index, self.time_offset, device=self.device) + self._init_behavior_distribution() + self._init_time_distribution() + + def _init_behavior_distribution(self): + if self.conditional == "time": + self.behavior_distribution = cebra.distributions.TimeContrastive( + time_offset=self.time_offset, + num_samples=len(self.dataset.neural), + device=self.device, + ) + + if self.conditional == "time_delta": + self.behavior_distribution = cebra.distributions.TimedeltaDistribution( + self.dataset.continuous_index, + self.time_offset, + device=self.device) + + elif self.conditional == "delta_normal": + self.behavior_distribution = cebra.distributions.DeltaNormalDistribution( + self.dataset.continuous_index, self.delta, device=self.device) + + elif self.conditional == "time": + self.behavior_distribution = cebra.distributions.TimeContrastive( + time_offset=self.time_offset, + num_samples=len(self.dataset.neural), + device=self.device, + ) + + def _init_time_distribution(self): + + if self.time_distribution == "time": + self.time_distribution = cebra.distributions.TimeContrastive( + time_offset=self.time_offset, + num_samples=len(self.dataset.neural), + device=self.device, + ) + + elif self.time_distribution == "time_delta": + self.time_distribution = cebra.distributions.TimedeltaDistribution( + self.dataset.continuous_index, + self.time_offset, + device=self.device) + + elif self.time_distribution == "delta_normal": + self.time_distribution = cebra.distributions.DeltaNormalDistribution( + self.dataset.continuous_index, self.delta, device=self.device) + + # TODO(stes): Add this distribution from internal xCEBRA codebase at a later point + #elif self.time_distribution == "delta_vmf": + # self.time_distribution = cebra.distributions.DeltaVMFDistribution( + # self.dataset.continuous_index, self.delta, device=self.device) + else: + raise ValueError def get_indices(self, num_samples: int) -> BatchIndex: """Samples indices for reference, positive and negative examples. diff --git a/cebra/models/__init__.py b/cebra/models/__init__.py index 4dfad333..2d170e24 100644 --- a/cebra/models/__init__.py +++ b/cebra/models/__init__.py @@ -36,5 +36,7 @@ from cebra.models.multiobjective import * from cebra.models.layers import * from cebra.models.criterions import * +from cebra.models.multicriterions import * +from cebra.models.jacobian_regularizer import * cebra.registry.add_docstring(__name__) diff --git a/cebra/models/jacobian_regularizer.py b/cebra/models/jacobian_regularizer.py new file mode 100644 index 00000000..a909a31b --- /dev/null +++ b/cebra/models/jacobian_regularizer.py @@ -0,0 +1,148 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# This file contains the PyTorch implementation of Jacobian regularization described in [1]. +# Judy Hoffman, Daniel A. 
Roberts, and Sho Yaida, +# "Robust Learning with Jacobian Regularization," 2019. +# [arxiv:1908.02729](https://arxiv.org/abs/1908.02729) +# +# Adapted from https://github.com/facebookresearch/jacobian_regularizer/blob/main/jacobian/jacobian.py +# licensed under the following MIT License: +# +# MIT License +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +"""Jacobian Regularization for CEBRA. + +This implementation is adapted from the Jacobian regularization described in [1]_. + +.. [1] Judy Hoffman, Daniel A. Roberts, and Sho Yaida, + "Robust Learning with Jacobian Regularization," 2019. + `arxiv:1908.02729 `_ +""" + +from __future__ import division + +import torch +import torch.nn as nn + + +class JacobianReg(nn.Module): + """Loss criterion that computes the trace of the square of the Jacobian. + + Args: + n: Determines the number of random projections. If n=-1, then it is set to the dimension + of the output space and projection is non-random and orthonormal, yielding the exact + result. For any reasonable batch size, the default (n=1) should be sufficient. + |Default:| ``1`` + """ + + def __init__(self, n: int = 1): + assert n == -1 or n > 0 + self.n = n + super(JacobianReg, self).__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Computes (1/2) tr \\|dy/dx\\|^2. + + Args: + x: Input tensor + y: Output tensor + + Returns: + The computed regularization term + """ + B, C = y.shape + if self.n == -1: + num_proj = C + else: + num_proj = self.n + J2 = 0 + for ii in range(num_proj): + if self.n == -1: + # orthonormal vector, sequentially spanned + v = torch.zeros(B, C) + v[:, ii] = 1 + else: + # random properly-normalized vector for each sample + v = self._random_vector(C=C, B=B) + if x.is_cuda: + v = v.cuda() + Jv = self._jacobian_vector_product(y, x, v, create_graph=True) + J2 += C * torch.norm(Jv)**2 / (num_proj * B) + R = (1 / 2) * J2 + return R + + def _random_vector(self, C: int, B: int) -> torch.Tensor: + """Creates a random vector of dimension C with a norm of C^(1/2). + + This is needed for the projection formula to work. 
+ + Args: + C: Output dimension + B: Batch size + + Returns: + A random normalized vector + """ + if C == 1: + return torch.ones(B) + v = torch.randn(B, C) + arxilirary_zero = torch.zeros(B, C) + vnorm = torch.norm(v, 2, 1, True) + v = torch.addcdiv(arxilirary_zero, 1.0, v, vnorm) + return v + + def _jacobian_vector_product(self, + y: torch.Tensor, + x: torch.Tensor, + v: torch.Tensor, + create_graph: bool = False) -> torch.Tensor: + """Produce jacobian-vector product dy/dx dot v. + + Args: + y: Output tensor + x: Input tensor + v: Vector to compute product with + create_graph: If True, graph of the derivative will be constructed, allowing + to compute higher order derivative products. |Default:| ``False`` + + Returns: + The Jacobian-vector product + + Note: + If you want to differentiate the result, you need to make create_graph=True + """ + flat_y = y.reshape(-1) + flat_v = v.reshape(-1) + grad_x, = torch.autograd.grad(flat_y, + x, + flat_v, + retain_graph=True, + create_graph=create_graph) + return grad_x diff --git a/cebra/models/layers.py b/cebra/models/layers.py index 7c1c36e8..e8b8175e 100644 --- a/cebra/models/layers.py +++ b/cebra/models/layers.py @@ -97,3 +97,25 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: connect = self.layer(inp) downsampled = F.interpolate(inp, scale_factor=1 / self.downsample) return torch.cat([connect, downsampled[..., :connect.size(-1)]], dim=1) + + +class _SkipLinear(nn.Module): + """Add a skip connection to a linear module + Args: + module (torch.nn.Module): Module to add to the bottleneck + """ + + def __init__(self, module): + super().__init__() + self.module = module + assert isinstance(self.module, nn.Linear) + padding_size = self.module.out_features - self.module.in_features + self.padding_size = padding_size + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + """Compute forward pass through the skip connection. + """ + inp_padded = F.pad(inp, (0, self.padding_size), + mode='constant', + value=0) + return inp_padded + self.module(inp) diff --git a/cebra/models/model.py b/cebra/models/model.py index 7631ba86..a74b0229 100644 --- a/cebra/models/model.py +++ b/cebra/models/model.py @@ -29,6 +29,7 @@ import cebra.data import cebra.data.datatypes import cebra.models.layers as cebra_layers +from cebra.models import parametrize from cebra.models import register @@ -780,3 +781,261 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): def get_offset(self) -> cebra.data.datatypes.Offset: """See `:py:meth:Model.get_offset`""" return cebra.data.Offset(18, 18) + + +@register("offset15-model") +class Offset15Model(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a 15 sample receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." 
+ ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + nn.Conv1d(num_units, num_output, 2), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(7, 8) + + +@register("offset20-model") +class Offset20Model(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a 15 sample receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." + ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(10, 10) + + +@register("offset10-model-mse-tanh") +class Offset10Model(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a 10 sample receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=False): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." 
+ ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + nn.Conv1d(num_units, num_output, 3), + nn.Tanh(), # Added tanh activation function + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(5, 5) + + +@register("offset1-model-mse-tanh") +class Offset0ModelMSETanH(_OffsetModel): + """CEBRA model with a single sample receptive field, without output normalization.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=False): + super().__init__( + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear( + num_neurons, + num_output * 30, + ), + nn.GELU(), + nn.Linear(num_output * 30, num_output * 30), + nn.GELU(), + nn.Linear(num_output * 30, num_output * 10), + nn.GELU(), + nn.Linear(int(num_output * 10), num_output), + nn.Tanh(), # Added tanh activation function + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(0, 1) + + +@parametrize("offset1-model-mse-clip-{clip_min}-{clip_max}", + clip_min=(1000, 100, 50, 25, 20, 15, 10, 5, 1), + clip_max=(1000, 100, 50, 25, 20, 15, 10, 5, 1)) +class Offset0ModelMSEClip(_OffsetModel): + """CEBRA model with a single sample receptive field, without output normalization.""" + + def __init__(self, + num_neurons, + num_units, + num_output, + clip_min=-1, + clip_max=1, + normalize=False): + super().__init__( + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear( + num_neurons, + num_output * 30, + ), + nn.GELU(), + nn.Linear(num_output * 30, num_output * 30), + nn.GELU(), + nn.Linear(num_output * 30, num_output * 10), + nn.GELU(), + nn.Linear(int(num_output * 10), num_output), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + self.clamp = nn.Hardtanh(-clip_min, clip_max) + + def forward(self, inputs): + outputs = super().forward(inputs) + outputs = self.clamp(outputs) + return outputs + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(0, 1) + + +@parametrize("offset1-model-mse-v2-{n_intermediate_layers}layers{tanh}", + n_intermediate_layers=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + tanh=("-tanh", "")) +class Offset0ModelMSETanHv2(_OffsetModel): + """CEBRA model with a single sample receptive field, without output normalization.""" + + def __init__(self, + num_neurons, + num_units, + num_output, + tanh="", + n_intermediate_layers=1, + normalize=False): + if num_units < 2: + raise ValueError( + f"Number of hidden units needs to be at least 2, but got {num_units}." 
+ ) + + intermediate_layers = [ + nn.Linear(num_units, num_units), + nn.GELU(), + ] * n_intermediate_layers + + layers = [ + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear( + num_neurons, + num_units, + ), + nn.GELU(), + *intermediate_layers, + nn.Linear(num_units, int(num_units // 2)), + nn.GELU(), + nn.Linear(int(num_units // 2), num_output), + ] + + if tanh == "-tanh": + layers += [nn.Tanh()] + + super().__init__( + *layers, + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(0, 1) + + +@parametrize("offset1-model-mse-resnet-{n_intermediate_layers}layers{tanh}", + n_intermediate_layers=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + tanh=("-tanh", "")) +class Offset0ModelResNetTanH(_OffsetModel): + """CEBRA model with a single sample receptive field, without output normalization.""" + + def __init__(self, + num_neurons, + num_units, + num_output, + tanh="", + n_intermediate_layers=1, + normalize=False): + if num_units < 2: + raise ValueError( + f"Number of hidden units needs to be at least 2, but got {num_units}." + ) + + intermediate_layers = [ + cebra_layers._SkipLinear(nn.Linear(num_units, num_units)), + nn.GELU(), + ] * n_intermediate_layers + + layers = [ + nn.Flatten(start_dim=1, end_dim=-1), + cebra_layers._SkipLinear(nn.Linear( + num_neurons, + num_units, + )), + nn.GELU(), + *intermediate_layers, + cebra_layers._SkipLinear(nn.Linear(num_units, int(num_units // 2))), + nn.GELU(), + nn.Linear(int(num_units // 2), num_output), + ] + + if tanh == "-tanh": + layers += [nn.Tanh()] + + super().__init__( + *layers, + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(0, 1) diff --git a/cebra/models/multicriterions.py b/cebra/models/multicriterions.py new file mode 100644 index 00000000..2b02fc37 --- /dev/null +++ b/cebra/models/multicriterions.py @@ -0,0 +1,154 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Support for training CEBRA with multiple criteria. + +.. note:: + This module was introduced in CEBRA 0.6.0. + +""" +from typing import Tuple + +import torch +from torch import nn + +from cebra.data.datatypes import Batch + + +class MultiCriterions(nn.Module): + """A module for handling multiple loss functions with different criteria. + + This module allows combining multiple loss functions, each operating on specific + slices of the input data. It supports both supervised and contrastive learning modes. 
+ + Args: + losses: A list of dictionaries containing loss configurations. Each dictionary should have: + - 'indices': Tuple of (start, end) indices for the data slice + - 'supervised_loss': Dict with loss config for supervised mode + - 'contrastive_loss': Dict with loss config for contrastive mode + Loss configs should contain: + - 'name': Name of the loss function + - 'kwargs': Optional parameters for the loss function + mode: Either "supervised" or "contrastive" to specify the training mode + + The loss functions can be from torch.nn or custom implementations from cebra.models.criterions. + Each criterion is applied to its corresponding slice of the input data during forward pass. + + Example: + >>> import torch + >>> from cebra.data.datatypes import Batch + >>> # Define loss configurations for a hybrid model with both contrastive and supervised losses + >>> losses = [ + ... { + ... 'indices': (0, 10), # First 10 dimensions + ... 'contrastive_loss': { + ... 'name': 'InfoNCE', # Using CEBRA's InfoNCE loss + ... 'kwargs': {'temperature': 1.0} + ... }, + ... 'supervised_loss': { + ... 'name': 'nn.MSELoss', # Using PyTorch's MSE loss + ... 'kwargs': {} + ... } + ... }, + ... { + ... 'indices': (10, 20), # Next 10 dimensions + ... 'contrastive_loss': { + ... 'name': 'InfoNCE', # Using CEBRA's InfoNCE loss + ... 'kwargs': {'temperature': 0.5} + ... }, + ... 'supervised_loss': { + ... 'name': 'nn.L1Loss', # Using PyTorch's L1 loss + ... 'kwargs': {} + ... } + ... } + ... ] + >>> # Create sample predictions (2 batches of 32 samples each with 10 features) + >>> ref1 = torch.randn(32, 10) + >>> pos1 = torch.randn(32, 10) + >>> neg1 = torch.randn(32, 10) + >>> ref2 = torch.randn(32, 10) + >>> pos2 = torch.randn(32, 10) + >>> neg2 = torch.randn(32, 10) + >>> predictions = ( + ... Batch(reference=ref1, positive=pos1, negative=neg1), + ... Batch(reference=ref2, positive=pos2, negative=neg2) + ... 
) + >>> # Create multi-criterion module in contrastive mode + >>> multi_loss = MultiCriterions(losses, mode="contrastive") + >>> # Forward pass with multiple predictions + >>> losses = multi_loss(predictions) # Returns list of loss values + >>> assert len(losses) == 2 # One loss per criterion + """ + + def __init__(self, losses, mode): + super(MultiCriterions, self).__init__() + self.criterions = nn.ModuleList() + self.slices = [] + + for loss_info in losses: + slice_indices = loss_info['indices'] + + if mode == "supervised": + loss = loss_info['supervised_loss'] + elif mode == "contrastive": + loss = loss_info['contrastive_loss'] + else: + raise NotImplementedError + + loss_name = loss['name'] + loss_kwargs = loss.get('kwargs', {}) + + if loss_name.startswith("nn"): + name = loss_name.split(".")[-1] + criterion = getattr(torch.nn, name, None) + else: + import cebra.models + criterion = getattr(cebra.models.criterions, loss_name, None) + + if criterion is None: + raise ValueError(f"Loss {loss_name} not found.") + else: + criterion = criterion(**loss_kwargs) + + self.criterions.append(criterion) + self.slices.append(slice(*slice_indices)) + assert len(self.criterions) == len(self.slices) + + def forward(self, predictions: Tuple[Batch]): + + losses = [] + + for criterion, prediction in zip(self.criterions, predictions): + + if prediction.negative is None: + # supervised + #reference: data, positive: label + loss = criterion(prediction.reference, prediction.positive) + else: + #contrastive + loss, pos, neg = criterion(prediction.reference, + prediction.positive, + prediction.negative) + + losses.append(loss) + + assert len(self.criterions) == len(predictions) == len(losses) + return losses diff --git a/cebra/models/multiobjective.py b/cebra/models/multiobjective.py index d9393fdc..5dc4d247 100644 --- a/cebra/models/multiobjective.py +++ b/cebra/models/multiobjective.py @@ -19,19 +19,80 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Wrappers for using models with multiobjective solvers. - -.. note:: - - Experimental as of Nov 06, 2022. 
-""" - -from typing import Tuple +import itertools +from typing import List, Tuple import torch from torch import nn import cebra.models +import cebra.models.model as cebra_models_base + + +def create_multiobjective_model(module, + **kwargs) -> "SubspaceMultiobjectiveModel": + assert isinstance(module, cebra_models_base.Model) + if isinstance(module, cebra.models.ConvolutionalModelMixin): + return SubspaceMultiobjectiveConvolutionalModel(module=module, **kwargs) + else: + return SubspaceMultiobjectiveModel(module=module, **kwargs) + + +def check_slices_for_gaps(slice_list): + slice_list = sorted(slice_list, key=lambda s: s.start) + for i in range(1, len(slice_list)): + if slice_list[i - 1].stop < slice_list[i].start: + raise ValueError( + f"There is a gap in the slices {slice_list[i-1]} and {slice_list[i]}" + ) + + +def check_overlapping_feature_ranges(slice_list): + for slice1, slice2 in itertools.combinations(slice_list, 2): + if slice1.start < slice2.stop and slice1.stop > slice2.start: + return True + return False + + +def compute_renormalize_ranges(feature_ranges, sort=True): + + max_slice_dim = max(s.stop for s in feature_ranges) + min_slice_dim = min(s.start for s in feature_ranges) + full_emb_slice = slice(min_slice_dim, max_slice_dim) + + n_full_emb_slices = sum(1 for s in feature_ranges if s == full_emb_slice) + + if n_full_emb_slices > 1: + raise ValueError( + "There are more than one slice that cover the full embedding.") + + if n_full_emb_slices == 0: + raise ValueError( + "There are overlapping slices but none of them cover the full embedding." + ) + + rest_of_slices = [s for s in feature_ranges if s != full_emb_slice] + max_slice_dim_rest = max(s.stop for s in rest_of_slices) + min_slice_dim_rest = min(s.start for s in rest_of_slices) + + remaining_slices = [] + if full_emb_slice.start < min_slice_dim_rest: + remaining_slices.append(slice(full_emb_slice.start, min_slice_dim_rest)) + + if full_emb_slice.stop > max_slice_dim_rest: + remaining_slices.append(slice(max_slice_dim_rest, full_emb_slice.stop)) + + if len(remaining_slices) == 0: + raise ValueError( + "The behavior slices and the time slices coincide completely.") + + final_slices = remaining_slices + rest_of_slices + + if sort: + final_slices = sorted(final_slices, key=lambda s: s.start) + + check_slices_for_gaps(final_slices) + return final_slices class _Norm(nn.Module): @@ -68,6 +129,13 @@ class MultiobjectiveModel(nn.Module): TODO: - Update nn.Module type annotation for ``module`` to cebra.models.Model + + Note: + This model will be deprecated in a future version. Please use the functionality in + :py:mod:`cebra.models.multiobjective` instead, which provides more versatile + multi-objective training capabilities. Instantiation of this model will raise a + deprecation warning. The new model is :py:class:`cebra.models.multiobjective.SubspaceMultiobjectiveModel` + which allows for unlimited subspaces and better configuration of the feature ranges. """ class Mode: @@ -178,3 +246,122 @@ def forward(self, inputs): if self.renormalize: outputs = (self._norm(output) for output in outputs) return tuple(outputs) + + +class SubspaceMultiobjectiveModel(nn.Module): + """Wrapper around contrastive learning models to all training with multiple objectives + + Multi-objective training splits the last layer's feature representation into multiple + chunks, which are then used for individual training objectives. 
+ + Args: + module: The module to wrap + dimensions: A tuple of dimension values to extract from the model's feature embedding. + renormalize: If True, the individual feature slices will be re-normalized before + getting returned---this option only makes sense in conjunction with a loss based + on the cosine distance or dot product. + TODO: + - Update nn.Module type annotation for ``module`` to cebra.models.Model + """ + + def __init__(self, + module: nn.Module, + feature_ranges: List[slice], + renormalize: bool, + split_outputs: bool = True): + super().__init__() + + if not isinstance(module, cebra.models.Model): + raise ValueError("Can only wrap models that are subclassing the " + "cebra.models.Model abstract base class. " + f"Got a model of type {type(module)}.") + + self.module = module + self.renormalize = renormalize + self._norm = _Norm() + self.feature_ranges = feature_ranges + self.split_outputs = split_outputs + + max_slice_dim = max(s.stop for s in self.feature_ranges) + min_slice_dim = min(s.start for s in self.feature_ranges) + if min_slice_dim != 0: + raise ValueError( + f"The first slice should start at 0, but it starts at {min_slice_dim}." + ) + + if max_slice_dim != self.num_output: + raise ValueError( + f"The dimension of output {self.num_output} is different than the highest dimension of the slices ({max_slice_dim})." + f"The output dimension and slice dimension need to have the same dimension." + ) + + check_slices_for_gaps(self.feature_ranges) + + if check_overlapping_feature_ranges(self.feature_ranges): + print("Computing renormalized ranges...") + self.renormalize_ranges = compute_renormalize_ranges( + self.feature_ranges, sort=True) + print("New ranges:", self.renormalize_ranges) + + def set_split_outputs(self, val): + assert isinstance(val, bool) + self.split_outputs = val + + @property + def get_offset(self): + """See :py:meth:`cebra.models.model.Model.get_offset`.""" + return self.module.get_offset + + @property + def num_output(self): + """See :py:attr:`cebra.models.model.Model.num_output`.""" + return self.module.num_output + + def forward(self, inputs): + """Compute multiple embeddings for a single signal input. + + Args: + inputs: The input tensor + + Returns: + A tuple of tensors which are sliced according to `self.feature_ranges` + if `renormalize` is set to true, each of the tensors will be normalized + across the first (feature) dimension. 
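# Hedged end-to-end sketch of the wrapper above: take the pre-existing "offset1-model"
# registry entry, wrap it so its 8D embedding is split into a 5D and a 3D subspace, and
# run a forward pass. The dimensions are placeholders, not part of this diff.
import torch
import cebra.models
from cebra.models.multiobjective import create_multiobjective_model

base = cebra.models.init("offset1-model", num_neurons=30, num_units=32, num_output=8)
wrapped = create_multiobjective_model(module=base,
                                      feature_ranges=[slice(0, 5), slice(5, 8)],
                                      renormalize=False)

z_a, z_b = wrapped(torch.randn(10, 30))     # split_outputs=True by default
assert z_a.shape == (10, 5) and z_b.shape == (10, 3)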
+ """ + + output = self.module(inputs) + + if (not self.renormalize) and (not self.split_outputs): + return output + + if self.renormalize: + if hasattr(self, "renormalize_ranges"): + if not all(self.renormalize_ranges[i].start <= + self.renormalize_ranges[i + 1].start + for i in range(len(self.renormalize_ranges) - 1)): + raise ValueError( + "The renormalize_ranges must be sorted by start index.") + + output = [ + self._norm(output[:, slice_features]) + for slice_features in self.renormalize_ranges + ] + else: + output = [ + self._norm(output[:, slice_features]) + for slice_features in self.feature_ranges + ] + + output = torch.cat(output, dim=1) + + if self.split_outputs: + return tuple(output[:, slice_features] + for slice_features in self.feature_ranges) + else: + assert isinstance(output, torch.Tensor) + return output + + +class SubspaceMultiobjectiveConvolutionalModel( + SubspaceMultiobjectiveModel, cebra_models_base.ConvolutionalModelMixin): + pass diff --git a/cebra/solver/__init__.py b/cebra/solver/__init__.py index 12ad2f06..965c16c8 100644 --- a/cebra/solver/__init__.py +++ b/cebra/solver/__init__.py @@ -37,6 +37,9 @@ # pylint: disable=wrong-import-position from cebra.solver.base import * from cebra.solver.multi_session import * +from cebra.solver.multiobjective import * +from cebra.solver.regularized import * +from cebra.solver.schedulers import * from cebra.solver.single_session import * from cebra.solver.supervised import * diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 14a22c68..992f4dae 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -32,6 +32,7 @@ import abc import os +import warnings from typing import Callable, Dict, List, Literal, Optional import literate_dataclasses as dataclasses @@ -367,11 +368,19 @@ class MultiobjectiveSolver(Solver): for time contrastive learning. renormalize_features: If ``True``, normalize the behavior and time contrastive features individually before computing similarity scores. + ignore_deprecation_warning: If ``True``, suppress the deprecation warning. + + Note: + This solver will be deprecated in a future version. Please use the functionality in + :py:mod:`cebra.solver.multiobjective` instead, which provides more versatile + multi-objective training capabilities. Instantiation of this solver will raise a + deprecation warning. """ num_behavior_features: int = 3 renormalize_features: bool = False output_mode: Literal["overlapping", "separate"] = "overlapping" + ignore_deprecation_warning: bool = False @property def num_time_features(self): @@ -383,6 +392,13 @@ def num_total_features(self): def __post_init__(self): super().__post_init__() + if not self.ignore_deprecation_warning: + warnings.warn( + "MultiobjectiveSolver is deprecated since CEBRA 0.6.0 and will be removed in a future version. " + "Use the new functionality in cebra.solver.multiobjective instead, which is more versatile. " + "If you see this warning when using the scikit-learn interface, no action is required.", + DeprecationWarning, + stacklevel=2) self._check_dimensions() self.model = cebra.models.MultiobjectiveModel( self.model, diff --git a/cebra/solver/multiobjective.py b/cebra/solver/multiobjective.py new file mode 100644 index 00000000..d4aa187d --- /dev/null +++ b/cebra/solver/multiobjective.py @@ -0,0 +1,527 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. 
Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Multiobjective contrastive learning. + +Starting in CEBRA 0.6.0, we have added support for subspace contrastive learning. +This is a method for training models that are able to learn multiple subspaces of the +feature space simultaneously. + +Subspace contrastive learning requires to use specialized models and criterions. +This module specifies a test of classes required for training CEBRA models with multiple objectives. +The objectives are defined by the wrapper class :py:class:`cebra.models.multicriterions.MultiCriterions`. + +Two solvers are currently implemented: + +- :py:class:`cebra.solver.multiobjective.ContrastiveMultiobjectiveSolverxCEBRA` +- :py:class:`cebra.solver.multiobjective.SupervisedMultiobjectiveSolverxCEBRA` + +See Also: + :py:class:`cebra.solver.multiobjective.SupervisedMultiobjectiveSolverxCEBRA` + :py:class:`cebra.solver.multiobjective.MultiObjectiveConfig` + :py:class:`cebra.models.multicriterions.MultiCriterions` +""" + +import logging +import time +import warnings +from typing import Callable, Dict, List, Optional, Tuple + +import literate_dataclasses as dataclasses +import numpy as np +import torch + +import cebra +import cebra.data +import cebra.io +import cebra.models +from cebra.solver import register +from cebra.solver.base import Solver +from cebra.solver.schedulers import Scheduler +from cebra.solver.util import Meter + + +class MultiObjectiveConfig: + """Configuration class for setting up multi-objective learning with Cebra. + + + + Args: + loader: Data loader used for configurations. + """ + + def __init__(self, loader): + self.loader = loader + self.total_info = [] + self.current_info = {} + + def _check_overwriting_key(self, key): + if key in self.current_info: + warnings.warn( + "Configuration key already exists. Overwriting existing value. " + "If you don't want to overwrite you should call push() before.") + + def _check_pushed_status(self): + if "slice" not in self.current_info: + raise RuntimeError( + "Slice configuration is missing. Add it before pushing it.") + if "distributions" not in self.current_info: + raise RuntimeError( + "Distributions configuration is missing. Add it before pushing it." + ) + if "losses" not in self.current_info: + raise RuntimeError( + "Losses configuration is missing. Add it before pushing it.") + + def set_slice(self, start, end): + """Select the index range of the embedding. + + The configured loss will be applied to the ``start:end`` slice of the + embedding space. Make sure the selected dimensionality is appropriate + for the chosen loss function and distribution. 
+ """ + self._check_overwriting_key("slice") + self.current_info['slice'] = (start, end) + + def set_loss(self, loss_name, **kwargs): + """Select the loss function to apply. + + Select a valid loss function from :py:mod:`cebra.models.criterions`. + Common choices are: + + - `FixedEuclideanInfoNCE` + - `FixedCosineInfoNCE` + + which can be passed as string values to ``loss_name``. The loss + will be applied to the range specified with ``set_slice``. + """ + self._check_overwriting_key("losses") + self.current_info["losses"] = {"name": loss_name, "kwargs": kwargs} + + def set_distribution(self, distribution_name, **kwargs): + """Select the distribution to sample from. + + The loss function specified in ``set_loss`` is applied to positive + and negative pairs sampled from the specified distribution. + """ + self._check_overwriting_key("distributions") + self.current_info["distributions"] = { + "name": distribution_name, + "kwargs": kwargs + } + + def push(self): + """Add a slice/loss/distribution setting to the config. + + After calling all of ``set_slice``, ``set_loss``, ``set_distribution``, + add this group to the config by calling this function. + + Once all configuration parts are pushed, call ``finalize`` to finish + the configuration. + """ + self._check_pushed_status() + print(f"Adding configuration for slice: {self.current_info['slice']}") + self.total_info.append(self.current_info) + self.current_info = {} + + def finalize(self): + """Finalize the multiobjective configuration.""" + self.losses = [] + self.feature_ranges = [] + self.feature_ranges_tuple = [] + + for info in self.total_info: + self._process_info(info) + + if len(set(self.feature_ranges_tuple)) != len( + self.feature_ranges_tuple): + raise RuntimeError( + f"Feature ranges are not unique. Please check again and remove the duplicates. " + f"Feature ranges: {self.feature_ranges_tuple}") + + print("Creating MultiCriterion") + self.criterion = cebra.models.MultiCriterions(losses=self.losses, + mode="contrastive") + + def _process_info(self, info): + """ + Processes individual configuration info and updates the losses and feature ranges. + + Args: + info (dict): The configuration info to process. 
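# Hedged usage sketch for the configuration API above, mirroring the integration test
# later in this diff. `loader` is assumed to be a ContrastiveMultiObjectiveLoader, and the
# loss/distribution names follow the strings used elsewhere in this changeset.
from cebra.solver import MultiObjectiveConfig

config = MultiObjectiveConfig(loader)

config.set_slice(0, 6)                                      # loss applied to the full 6D embedding
config.set_loss("FixedEuclideanInfoNCE", temperature=1.0)
config.set_distribution("time", time_offset=1)
config.push()

config.set_slice(3, 6)                                      # second loss on a 3D behavior subspace
config.set_loss("FixedEuclideanInfoNCE", temperature=1.0)
config.set_distribution("time_delta", time_delta=1, label_name="Z2")
config.push()

config.finalize()   # builds config.criterion (a MultiCriterions) and config.feature_ranges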
+ """ + slice_info = info["slice"] + losses_info = info["losses"] + distributions_info = info["distributions"] + + self.losses.append( + dict(indices=(slice_info[0], slice_info[1]), + contrastive_loss=dict(name=losses_info['name'], + kwargs=losses_info['kwargs']))) + + self.feature_ranges.append(slice(slice_info[0], slice_info[1])) + self.feature_ranges_tuple.append((slice_info[0], slice_info[1])) + + print(f"Adding distribution of slice: {slice_info}") + self.loader.add_config( + dict(distribution=distributions_info["name"], + kwargs=distributions_info["kwargs"])) + + +@dataclasses.dataclass +class MultiobjectiveSolverBase(Solver): + + feature_ranges: List[slice] = None + renormalize: bool = None + log: Dict[Tuple, + List[float]] = dataclasses.field(default_factory=lambda: ({})) + use_sam: bool = False + regularizer: torch.nn.Module = None + metadata: Dict = dataclasses.field(default_factory=lambda: ({ + "timestamp": None, + "batches_seen": None, + })) + + def __post_init__(self): + super().__post_init__() + + self.model = cebra.models.create_multiobjective_model( + module=self.model, + feature_ranges=self.feature_ranges, + renormalize=self.renormalize, + ) + + def fit(self, + loader: cebra.data.Loader, + valid_loader: cebra.data.Loader = None, + *, + valid_frequency: int = None, + log_frequency: int = None, + save_hook: Callable[[int, "Solver"], None] = None, + scheduler_regularizer: "Scheduler" = None, + scheduler_loss: "Scheduler" = None, + logger: logging.Logger = None): + """Train model for the specified number of steps. + + Args: + loader: Data loader, which is an iterator over `cebra.data.Batch` instances. + Each batch contains reference, positive and negative input samples. + valid_loader: Data loader used for validation of the model. + valid_frequency: The frequency for running validation on the ``valid_loader`` instance. + logdir: The logging directory for writing model checkpoints. The checkpoints + can be read again using the `solver.load` function, or manually via loading the + state dict. + save_hook: callback. It will be called when we run validation. + log_frequency: how frequent we log things. + logger: logger to log progress. None by default. + + """ + + def _run_validation(): + stats_val = self.validation(valid_loader, logger=logger) + if save_hook is not None: + save_hook(solver=self, step=num_steps) + return stats_val + + self.to(loader.device) + + iterator = self._get_loader(loader, + logger=logger, + log_frequency=log_frequency) + self.model.train() + for num_steps, batch in iterator: + weights_regularizer = None + if scheduler_regularizer is not None: + weights_regularizer = scheduler_regularizer.get_weights( + step=num_steps) + # NOTE(stes): Both SAM and Jacobian regularization is not yet supported. + # For this, we need to re-implement the closure logic below (right now, + # the closure function applies the non-regularized loss in the second + # step, it is unclear if that is the correct behavior. 
+ assert not self.use_sam + + weights_loss = None + if scheduler_loss is not None: + weights_loss = scheduler_loss.get_weights() + + stats = self.step(batch, + weights_regularizer=weights_regularizer, + weights_loss=weights_loss) + + self._update_metadata(num_steps) + iterator.set_description(stats) + run_validation = (valid_loader + is not None) and (num_steps % valid_frequency + == 0) + if run_validation: + _run_validation() + + #TODO + #_run_validation() + + def _get_loader(self, loader, **kwargs): + return super()._get_loader(loader) + + def _update_metadata(self, num_steps): + self.metadata["timestamp"] = time.time() + self.metadata["batches_seen"] = num_steps + + def compute_regularizer(self, predictions, inputs): + regularizer = [] + for prediction in predictions: + R = self.regularizer(inputs, prediction.reference) + regularizer.append(R) + + return regularizer + + def create_closure(self, batch, weights_loss): + + def inner_closure(): + predictions = self._inference(batch) + losses = self.criterion(predictions) + + if weights_loss is not None: + assert len(weights_loss) == len( + losses + ), "Number of weights should match the number of losses" + losses = [ + weight * loss for weight, loss in zip(weights_loss, losses) + ] + + loss = sum(losses) + loss.backward() + return loss + + return inner_closure + + def step(self, + batch: cebra.data.Batch, + weights_loss: Optional[List[float]] = None, + weights_regularizer: Optional[List[float]] = None) -> dict: + """Perform a single gradient update with multiple objectives.""" + + closure = None + if self.use_sam: + closure = self.create_closure(batch, weights_loss) + + if weights_regularizer is not None: + assert isinstance(batch.reference, torch.Tensor) + batch.reference.requires_grad_(True) + + predictions = self._inference(batch) + losses = self.criterion(predictions) + + for i, loss_value in enumerate(losses): + key = "loss_train", i + self.log.setdefault(key, []).append(loss_value.item()) + + if weights_loss is not None: + losses = [ + weight * loss for weight, loss in zip(weights_loss, losses) + ] + + loss = sum(losses) + + if weights_regularizer is not None: + regularizer = self.compute_regularizer(predictions=predictions, + inputs=batch.reference) + assert len(weights_regularizer) == len(regularizer) == len(losses) + loss = loss + sum( + weight * reg + for weight, reg in zip(weights_regularizer, regularizer)) + + loss.backward() + self.optimizer.step(closure) + self.optimizer.zero_grad() + + if weights_regularizer is not None: + for i, (weight, + reg) in enumerate(zip(weights_regularizer, regularizer)): + assert isinstance(weight, float) + self.log.setdefault(("regularizer", i), []).append(reg.item()) + self.log.setdefault(("regularizer_weight", i), + []).append(weight) + + if weights_loss is not None: + for i, weight in enumerate(weights_loss): + assert isinstance(weight, float) or isinstance(weight, int) + self.log.setdefault(("loss_weight", i), []).append(weight) + + # add sum_loss_train + self.log.setdefault(("sum_loss_train",), []).append(loss.item()) + return {"sum_loss_train": loss.item()} + + @torch.no_grad() + def _compute_metrics(self): + # NOTE: We set split_outputs = False when we compute + # validation metrics, otherwise it returns a tuple + # which led to a bug before. 
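# The solver's `log` dictionary above is keyed by tuples so that per-objective curves can
# be stored side by side. A short sketch of reading it back, assuming `solver` is a fitted
# MultiobjectiveSolverBase with two objectives:
first_objective_curve = solver.log[("loss_train", 0)]      # list of floats, one per step
second_objective_curve = solver.log[("loss_train", 1)]
total_curve = solver.log[("sum_loss_train",)]               # summed (weighted) loss per step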
+ embeddings = {} + self.model.set_split_outputs(False) + for split in self.metrics.splits: + embedding_tensor = self.transform( + self.metrics.datasets[split].neural) + embedding_np = embedding_tensor.cpu().numpy() + assert embedding_np.shape[1] == self.model.num_output + embeddings[split] = embedding_np + + self.model.set_split_outputs(True) + return self.metrics.compute_metrics(embeddings) + + @torch.no_grad() + def validation( + self, + loader: cebra.data.Loader, + logger=None, + weights_loss: Optional[List[float]] = None, + ): + self.model.eval() + total_loss = Meter() + + losses_dict = {} + for _, batch in enumerate(loader): + predictions = self._inference(batch) + losses = self.criterion(predictions) + + if weights_loss is not None: + assert len(weights_loss) == len( + losses + ), "Number of weights should match the number of losses" + losses = [ + weight * loss for weight, loss in zip(weights_loss, losses) + ] + + total_loss.add(sum(losses).item()) + + for i, loss_value in enumerate(losses): + key = "loss_val", i + losses_dict.setdefault(key, []).append(loss_value.item()) + + losses_dict_mean = {k: np.mean(v) for k, v in losses_dict.items()} + stats_val = {**losses_dict_mean} + + if self.metrics is not None: + metrics = self._compute_metrics() + stats_val.update(metrics) + + for key, value in stats_val.items(): + self.log.setdefault(key, []).append(value) + + if logger is not None: + formatted_loss = ', '.join([ + f"{'_'.join(map(str, key))}:{value:.3f}" + for key, value in stats_val.items() + if key[0].startswith("loss") + ]) + formatted_r2 = ', '.join([ + f"{'_'.join(map(str, key))}:{value:.3f}" + for key, value in stats_val.items() + if key[0].startswith("r2") + ]) + logger.info(f"Val: {formatted_loss}") + logger.info(f"Val: {formatted_r2}") + + # add sum_loss_valid + sum_loss_valid = total_loss.average + self.log.setdefault(("sum_loss_val",), []).append(sum_loss_valid) + return stats_val + + @torch.no_grad() + def transform(self, inputs: torch.Tensor) -> torch.Tensor: + offset = self.model.get_offset() + self.model.eval() + X = inputs.cpu().numpy() + X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), mode="edge") + X = torch.from_numpy(X).float().to(self.device) + + if isinstance(self.model.module, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + X = X.transpose(1, 0).unsqueeze(0) + outputs = self.model(X) + + # switch back from (1, C, T) -> (T, C) + if isinstance(outputs, torch.Tensor): + assert outputs.dim() == 3 and outputs.shape[0] == 1 + outputs = outputs.squeeze(0).transpose(1, 0) + elif isinstance(outputs, tuple): + assert all(tensor.dim() == 3 and tensor.shape[0] == 1 + for tensor in outputs) + outputs = ( + output.squeeze(0).transpose(1, 0) for output in outputs) + outputs = tuple(outputs) + else: + raise ValueError("Invalid condition in solver.transform") + else: + # Standard evaluation, (T, C, dt) + outputs = self.model(X) + + return outputs + + +@register("supervised-solver-xcebra") +@dataclasses.dataclass +class SupervisedMultiobjectiveSolverxCEBRA(MultiobjectiveSolverBase): + """Supervised neural network training using the MSE loss. + + This solver can be used as a baseline variant instead of the contrastive solver, + :py:class:`cebra.solver.multiobjective.ContrastiveMultiobjectiveSolverxCEBRA`. 
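# Hedged sketch of wiring the pieces together with the contrastive solver registered
# below as "multiobjective-solver", mirroring the integration test later in this diff.
# `neural_model`, `config` and `opt` are stand-ins prepared as in that test.
import cebra.models
import cebra.solver

regularizer = cebra.models.jacobian_regularizer.JacobianReg()

solver = cebra.solver.init(
    name="multiobjective-solver",
    model=neural_model,
    feature_ranges=config.feature_ranges,
    regularizer=regularizer,
    renormalize=False,
    use_sam=False,
    criterion=config.criterion,
    optimizer=opt,
    tqdm_on=False,
)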
+ """ + + _variant_name = "supervised-solver-xcebra" + + def _inference(self, batch): + """Compute predictions (discrete/continuous) for the batch.""" + pred_refs = self.model(batch.reference) + prediction_batches = [] + for i, label_data in enumerate(batch.positive): + prediction_batches.append( + cebra.data.Batch(reference=pred_refs[i], + positive=label_data, + negative=None)) + return prediction_batches + + +@register("multiobjective-solver") +@dataclasses.dataclass +class ContrastiveMultiobjectiveSolverxCEBRA(MultiobjectiveSolverBase): + """Multi-objective solver for CEBRA. + + This solver is used for training CEBRA models with multiple objectives. + + See Also: + :py:class:`cebra.solver.multiobjective.SupervisedMultiobjectiveSolverxCEBRA` + :py:class:`cebra.solver.multiobjective.MultiObjectiveConfig` + :py:class:`cebra.models.multicriterions.MultiCriterions` + """ + + _variant_name = "contrastive-solver-xcebra" + + def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + pred_refs = self.model(batch.reference) + pred_negs = self.model(batch.negative) + + prediction_batches = [] + for i, positive in enumerate(batch.positive): + pred_pos = self.model(positive) + prediction_batches.append( + cebra.data.Batch(pred_refs[i], pred_pos[i], pred_negs[i])) + + return prediction_batches diff --git a/cebra/solver/regularized.py b/cebra/solver/regularized.py new file mode 100644 index 00000000..41284529 --- /dev/null +++ b/cebra/solver/regularized.py @@ -0,0 +1,105 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Regularized contrastive learning.""" + +from typing import Dict, Optional + +import literate_dataclasses as dataclasses +import torch + +import cebra +import cebra.data +import cebra.models +from cebra.solver import register +from cebra.solver.single_session import SingleSessionSolver + + +@register("regularized-solver") +@dataclasses.dataclass +class RegularizedSolver(SingleSessionSolver): + """Optimize a model using Jacobian Regularizer.""" + + _variant_name = "regularized-solver" + log: Dict = dataclasses.field(default_factory=lambda: ({ + "pos": [], + "neg": [], + "loss": [], + "loss_reg": [], + "temperature": [], + "reg": [], + "reg_lambda": [], + })) + + lambda_JR: Optional[float] = None + + def __post_init__(self): + super().__post_init__() + #TODO: rn we are using the full jacobian. Can be optimized later if needed. + self.jac_regularizer = cebra.models.JacobianReg(n=-1) + + def step(self, batch: cebra.data.Batch) -> dict: + """Perform a single gradient update using the jacobian regularizaiton!. + + Args: + batch: The input samples + + Returns: + Dictionary containing training metrics. 
+ """ + + self.optimizer.zero_grad() + batch.reference.requires_grad = True + prediction = self._inference(batch) + R = self.jac_regularizer(batch.reference, prediction.reference) + + loss, align, uniform = self.criterion(prediction.reference, + prediction.positive, + prediction.negative) + loss_reg = loss + self.lambda_JR * R + + loss_reg.backward() + self.optimizer.step() + self.history.append(loss.item()) + stats = dict(pos=align.item(), + neg=uniform.item(), + loss=loss.item(), + loss_reg=loss_reg.item(), + reg=R.item(), + temperature=self.criterion.temperature, + reg_lambda=(self.lambda_JR * R).item()) + + for key, value in stats.items(): + self.log[key].append(value) + return stats + + +def _prepare_inputs(inputs): + if not isinstance(inputs, torch.Tensor): + inputs = torch.from_numpy(inputs) + inputs.requires_grad_(True) + return inputs + + +def _prepare_model(model): + for p in model.parameters(): + p.requires_grad_(False) + return model diff --git a/cebra/solver/schedulers.py b/cebra/solver/schedulers.py new file mode 100644 index 00000000..1da637af --- /dev/null +++ b/cebra/solver/schedulers.py @@ -0,0 +1,97 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
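# The objective optimized by RegularizedSolver.step() above, written out as a sketch
# (InfoNCE here stands for whatever criterion the solver was configured with):
#
#   R        = (1/2) * tr ||d f(x_ref) / d x_ref||^2                  # JacobianReg on the reference batch
#   loss_reg = InfoNCE(f(x_ref), f(x_pos), f(x_neg)) + lambda_JR * R
#
# A hedged instantiation through the registry; `model`, `criterion` and `optimizer` below
# are stand-ins, not part of this diff.
import cebra.solver

solver = cebra.solver.init(
    name="regularized-solver",
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    lambda_JR=0.01,      # weight of the Jacobian penalty
    tqdm_on=False,
)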
+# +import abc +import dataclasses +from typing import List + +import cebra.registry + +cebra.registry.add_helper_functions(__name__) + +__all__ = ["Scheduler", "ConstantScheduler", "LinearScheduler", "LinearRampUp"] + + +@dataclasses.dataclass +class Scheduler(abc.ABC): + + def __post_init__(self): + pass + + @abc.abstractmethod + def get_weights(self): + pass + + +@register("constant-weight") +@dataclasses.dataclass +class ConstantScheduler(Scheduler): + initial_weights: List[float] + + def __post_init__(self): + super().__post_init__() + + def get_weights(self): + weights = self.initial_weights + if len(weights) == 0: + weights = None + return weights + + +@register("linear-scheduler") +@dataclasses.dataclass +class LinearScheduler(Scheduler): + n_splits: int + step_to_switch_on_reg: int + step_to_switch_off_reg: int + start_weight: float + end_weight: float + stay_constant_after_switch_off: bool = False + + def __post_init__(self): + super().__post_init__() + assert self.step_to_switch_off_reg > self.step_to_switch_on_reg + + def get_weights(self, step): + if self.step_to_switch_on_reg is not None: + if step >= self.step_to_switch_on_reg and step <= self.step_to_switch_off_reg: + interpolation_factor = min( + 1.0, (step - self.step_to_switch_on_reg) / + (self.step_to_switch_off_reg - self.step_to_switch_on_reg)) + weight = self.start_weight + ( + self.end_weight - self.start_weight) * interpolation_factor + weights = [weight] * self.n_splits + elif self.stay_constant_after_switch_off and step > self.step_to_switch_off_reg: + weight = self.end_weight + weights = [weight] * self.n_splits + else: + weights = None + + return weights + + +@register("linear-ramp-up") +@dataclasses.dataclass +class LinearRampUp(LinearScheduler): + + def __post_init__(self): + super().__post_init__() + self.stay_constant_after_switch_off = True diff --git a/docs/.gitignore b/docs/.gitignore index a48ebfca..f7176a04 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,2 +1,3 @@ build/ page/ +root/static diff --git a/docs/Makefile b/docs/Makefile index 26c260d3..9252ed72 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -31,7 +31,10 @@ clean: # Checkout the source repository for CEBRA figures. Note that this requires SSH access # and might prompt you for an SSH key. source/cebra-figures: - git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-figures.git source/cebra-figures + cd $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) && git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-figures.git source/cebra-figures + +source/demo_notebooks: + cd $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) && git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-demos.git source/demo_notebooks source/demo_notebooks: git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-demos.git source/demo_notebooks @@ -44,7 +47,7 @@ demos: source/demo_notebooks cd source/demo_notebooks && git pull --ff-only origin main source/assets: - git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-assets.git source/assets + cd $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) && git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-assets.git source/assets assets: source/assets cd source/assets && git pull --ff-only origin main diff --git a/docs/source/api.rst b/docs/source/api.rst index 8989337f..846602f1 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -38,6 +38,9 @@ these components in other contexts and research code bases. 
api/pytorch/distributions api/pytorch/models api/pytorch/helpers + api/pytorch/multiobjective + api/pytorch/regularized + api/pytorch/attribution .. toctree:: :hidden: diff --git a/docs/source/api/pytorch/attribution.rst b/docs/source/api/pytorch/attribution.rst new file mode 100644 index 00000000..6efb043f --- /dev/null +++ b/docs/source/api/pytorch/attribution.rst @@ -0,0 +1,21 @@ +=================== +Attribution Methods +=================== + +.. automodule:: cebra.attribution + :members: + :show-inheritance: + +Different attribution methods +----------------------------- + +.. automodule:: cebra.attribution.attribution_models + :members: + :show-inheritance: + +Jacobian-based attribution +-------------------------- + +.. automodule:: cebra.attribution.jacobian_attribution + :members: + :show-inheritance: diff --git a/docs/source/api/pytorch/models.rst b/docs/source/api/pytorch/models.rst index ee3455bc..3fe2219b 100644 --- a/docs/source/api/pytorch/models.rst +++ b/docs/source/api/pytorch/models.rst @@ -43,12 +43,8 @@ Layers and model building blocks :show-inheritance: Multi-objective models -~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: cebra.models.multiobjective - :members: - :private-members: - :show-inheritance: - -.. - - projector +The multi-objective interface was moved to a separate section beginning with CEBRA 0.6.0. +Please see the :doc:`Multi-objective models ` section +for all details, both on the old and new API interface. diff --git a/docs/source/api/pytorch/multiobjective.rst b/docs/source/api/pytorch/multiobjective.rst new file mode 100644 index 00000000..c959cfa1 --- /dev/null +++ b/docs/source/api/pytorch/multiobjective.rst @@ -0,0 +1,15 @@ +====================== +Multi-objective models +====================== + +.. automodule:: cebra.solver.multiobjective + :members: + :show-inheritance: + +.. automodule:: cebra.models.multicriterions + :members: + :show-inheritance: + +.. automodule:: cebra.models.multiobjective + :members: + :show-inheritance: diff --git a/docs/source/api/pytorch/regularized.rst b/docs/source/api/pytorch/regularized.rst new file mode 100644 index 00000000..7da94603 --- /dev/null +++ b/docs/source/api/pytorch/regularized.rst @@ -0,0 +1,24 @@ +================================ +Regularized Contrastive Learning +================================ + +Regularized solvers +-------------------- + +.. automodule:: cebra.solver.regularized + :members: + :show-inheritance: + +Schedulers +---------- + +.. automodule:: cebra.solver.schedulers + :members: + :show-inheritance: + +Jacobian Regularization +----------------------- + +.. automodule:: cebra.models.jacobian_regularizer + :members: + :show-inheritance: diff --git a/docs/source/conf.py b/docs/source/conf.py index 80399e5f..83c41fad 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -128,7 +128,7 @@ def get_years(start_year=2021): autodoc_member_order = "bysource" autodoc_mock_imports = [ "torch", "nlb_tools", "tqdm", "h5py", "pandas", "matplotlib", "plotly", - "joblib", "scikit-learn", "scipy", "requests", "sklearn" + "cvxpy", "captum", "joblib", "scikit-learn", "scipy", "requests", "sklearn" ] # autodoc_typehints = "none" @@ -139,9 +139,18 @@ def get_years(start_year=2021): # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. 
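# Returning to the weight schedulers added in cebra/solver/schedulers.py above: a short
# sketch of the ramp-up behaviour, assuming direct instantiation of LinearRampUp. The
# regularizer stays off before the first threshold, ramps linearly between the two
# thresholds, and keeps the final weight afterwards.
from cebra.solver.schedulers import LinearRampUp

ramp = LinearRampUp(n_splits=2,
                    step_to_switch_on_reg=100,
                    step_to_switch_off_reg=200,
                    start_weight=0.0,
                    end_weight=1.0)

assert ramp.get_weights(step=50) is None            # before the ramp: regularizer disabled
assert ramp.get_weights(step=150) == [0.5, 0.5]     # halfway through the ramp
assert ramp.get_weights(step=500) == [1.0, 1.0]     # constant after switch-off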
exclude_patterns = [ - "**/todo", "**/src", "cebra-figures/figures.rst", "cebra-figures/*.rst", - "*/cebra-figures/*.rst", "*/demo_notebooks/README.rst", - "demo_notebooks/README.rst" + "**/todo", + "**/src", + "cebra-figures/figures.rst", + "cebra-figures/*.rst", + "*/cebra-figures/*.rst", + "*/demo_notebooks/README.rst", + "demo_notebooks/README.rst", + # TODO(stes): Remove this from the assets repo, then remove here + "_static/figures_usage.ipynb", + "*/_static/figures_usage.ipynb", + "assets/**/*.ipynb", + "*/assets/**/*.ipynb" ] # -- Options for HTML output ------------------------------------------------- @@ -194,7 +203,7 @@ def get_years(start_year=2021): ], "collapse_navigation": False, "navigation_depth": 1, - "show_nav_level": 2, + "show_nav_level": 1, "navbar_align": "content", "show_prev_next": False, "navbar_end": ["theme-switcher", "navbar-icon-links.html"], diff --git a/reinstall.sh b/reinstall.sh index 422e5d17..ea8981b9 100755 --- a/reinstall.sh +++ b/reinstall.sh @@ -15,7 +15,7 @@ pip uninstall -y cebra # Get version info after uninstalling --- this will automatically get the # most recent version based on the source code in the current directory. # $(tools/get_cebra_version.sh) -VERSION=0.5.0 +VERSION=0.6.0a1 echo "Upgrading to CEBRA v${VERSION}" # Upgrade the build system (PEP517/518 compatible) diff --git a/setup.cfg b/setup.cfg index 40383b89..7faff998 100644 --- a/setup.cfg +++ b/setup.cfg @@ -63,6 +63,9 @@ integrations = pandas plotly seaborn + captum + cvxpy + scikit-image docs = sphinx sphinx-gallery diff --git a/tests/test_attribution.py b/tests/test_attribution.py new file mode 100644 index 00000000..cfb8ad7a --- /dev/null +++ b/tests/test_attribution.py @@ -0,0 +1,214 @@ +import numpy as np +import pytest +import torch + +import cebra.attribution._jacobian +import cebra.attribution.jacobian_attribution as jacobian_attribution +from cebra.attribution import attribution_models +from cebra.models import Model + + +class DummyModel(Model): + + def __init__(self): + super().__init__(num_input=10, num_output=5) + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + def get_offset(self): + return None + + +@pytest.fixture +def model(): + return DummyModel() + + +@pytest.fixture +def input_data(): + return torch.randn(100, 10) + + +def test_neuron_gradient_method(model, input_data): + attribution = attribution_models.NeuronGradientMethod(model=model, + input_data=input_data, + output_dimension=5) + + result = attribution.compute_attribution_map() + + assert 'neuron-gradient' in result + assert 'neuron-gradient-convabs' in result + assert result['neuron-gradient'].shape == (100, 5, 10) + + +def test_neuron_gradient_shap_method(model, input_data): + attribution = attribution_models.NeuronGradientShapMethod( + model=model, input_data=input_data, output_dimension=5) + + result = attribution.compute_attribution_map(baselines="zeros") + + assert 'neuron-gradient-shap' in result + assert 'neuron-gradient-shap-convabs' in result + assert result['neuron-gradient-shap'].shape == (100, 5, 10) + + with pytest.raises(NotImplementedError): + attribution.compute_attribution_map(baselines="invalid") + + +def test_feature_ablation_method(model, input_data): + attribution = attribution_models.FeatureAblationMethod( + model=model, input_data=input_data, output_dimension=5) + + result = attribution.compute_attribution_map() + + assert 'feature-ablation' in result + assert 'feature-ablation-convabs' in result + assert result['feature-ablation'].shape 
== (100, 5, 10) + + +def test_integrated_gradients_method(model, input_data): + attribution = attribution_models.IntegratedGradientsMethod( + model=model, input_data=input_data, output_dimension=5) + + result = attribution.compute_attribution_map() + + assert 'integrated-gradients' in result + assert 'integrated-gradients-convabs' in result + assert result['integrated-gradients'].shape == (100, 5, 10) + + +def test_batched_methods(model, input_data): + # Test batched version of NeuronGradientMethod + attribution = attribution_models.NeuronGradientMethodBatched( + model=model, input_data=input_data, output_dimension=5) + + result = attribution.compute_attribution_map(batch_size=32) + assert 'neuron-gradient' in result + assert result['neuron-gradient'].shape == (100, 5, 10) + + # Test batched version of IntegratedGradientsMethod + attribution = attribution_models.IntegratedGradientsMethodBatched( + model=model, input_data=input_data, output_dimension=5) + + result = attribution.compute_attribution_map(batch_size=32) + assert 'integrated-gradients' in result + assert result['integrated-gradients'].shape == (100, 5, 10) + + +def test_compute_metrics(): + attribution = attribution_models.AttributionMap(model=None, input_data=None) + + attribution_map = np.array([0.1, 0.8, 0.3, 0.9, 0.2]) + ground_truth = np.array([False, True, False, True, False]) + + metrics = attribution.compute_metrics(attribution_map, ground_truth) + + assert 'max_connected' in metrics + assert 'mean_connected' in metrics + assert 'min_connected' in metrics + assert 'max_nonconnected' in metrics + assert 'mean_nonconnected' in metrics + assert 'min_nonconnected' in metrics + assert 'gap_max' in metrics + assert 'gap_mean' in metrics + assert 'gap_min' in metrics + assert 'gap_minmax' in metrics + assert 'max_jacobian' in metrics + assert 'min_jacobian' in metrics + + +def test_compute_attribution_score(): + attribution = attribution_models.AttributionMap(model=None, input_data=None) + + attribution_map = np.array([0.1, 0.8, 0.3, 0.9, 0.2]) + ground_truth = np.array([False, True, False, True, False]) + + score = attribution.compute_attribution_score(attribution_map, ground_truth) + assert isinstance(score, float) + assert 0 <= score <= 1 + + +def test_jacobian_computation(): + # Create a simple model and input for testing + model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), + torch.nn.Linear(5, 3)) + input_data = torch.randn(100, 10, requires_grad=True) + + # Test basic Jacobian computation + jf, jhatg = jacobian_attribution.get_attribution_map(model=model, + input_data=input_data, + double_precision=True, + convert_to_numpy=True) + + # Check shapes + assert jf.shape == (100, 3, 10) # (batch_size, output_dim, input_dim) + assert jhatg.shape == (100, 10, 3) # (batch_size, input_dim, output_dim) + + +def test_tensor_conversion(): + # Test CPU and double precision conversion + test_tensors = [torch.randn(10, 5), torch.randn(5, 3)] + + converted = cebra.attribution._jacobian.tensors_to_cpu_and_double( + test_tensors) + + for tensor in converted: + assert tensor.device.type == "cpu" + assert tensor.dtype == torch.float64 + + # Only test CUDA conversion if CUDA is available + if torch.cuda.is_available(): + cuda_tensors = cebra.attribution._jacobian.tensors_to_cuda( + test_tensors, cuda_device="cuda") + for tensor in cuda_tensors: + assert tensor.is_cuda + else: + # Skip CUDA test with a message + pytest.skip("CUDA not available - skipping CUDA conversion test") + + +def test_jacobian_with_hybrid_solver(): + # 
diff --git a/tests/test_integration_xcebra.py b/tests/test_integration_xcebra.py
new file mode 100644
index 00000000..4e647916
--- /dev/null
+++ b/tests/test_integration_xcebra.py
@@ -0,0 +1,152 @@
+import pickle
+
+import pytest
+import torch
+
+import cebra
+import cebra.attribution
+import cebra.data
+import cebra.models
+import cebra.solver
+from cebra.data import ContrastiveMultiObjectiveLoader
+from cebra.data import DatasetxCEBRA
+from cebra.solver import MultiObjectiveConfig
+from cebra.solver.schedulers import LinearRampUp
+
+
+@pytest.fixture
+def synthetic_data():
+    import tempfile
+    import urllib.request
+    from pathlib import Path
+
+    url = "https://cebra.fra1.digitaloceanspaces.com/xcebra_synthetic_data.pkl"
+
+    # Create a persistent temp directory specific to this test
+    temp_dir = Path(tempfile.gettempdir()) / "cebra_test_data"
+    temp_dir.mkdir(exist_ok=True)
+    filepath = temp_dir / "synthetic_data.pkl"
+
+    if not filepath.exists():
+        urllib.request.urlretrieve(url, filepath)
+
+    with filepath.open('rb') as file:
+        return pickle.load(file)
+
+
+@pytest.fixture
+def device():
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def test_synthetic_data_training(synthetic_data, device):
+    # Setup data
+    neurons = synthetic_data['neurons']
+    latents = synthetic_data['latents']
+    n_latents = latents.shape[1]
+    Z1 = synthetic_data['Z1']
+    Z2 = synthetic_data['Z2']
+    gt_attribution_map = synthetic_data['gt_attribution_map']
+    data = DatasetxCEBRA(neurons, Z1=Z1, Z2=Z2)
+
+    # Configure training with reduced steps
+    TOTAL_STEPS = 50  # Reduced from 2000 for faster testing
+    loader = ContrastiveMultiObjectiveLoader(dataset=data,
+                                             num_steps=TOTAL_STEPS,
+                                             batch_size=512).to(device)
+
+    config = MultiObjectiveConfig(loader)
+    config.set_slice(0, 6)
+    config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+    config.set_distribution("time", time_offset=1)
+    config.push()
+
+    config.set_slice(3, 6)
+    config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+    config.set_distribution("time_delta", time_delta=1, label_name="Z2")
+    config.push()
+
+    config.finalize()
+
+    # Initialize model and solver
+    neural_model = cebra.models.init(
+        name="offset1-model-mse-clip-5-5",
+        num_neurons=data.neural.shape[1],
+        num_units=256,
+        num_output=n_latents,
+    ).to(device)
+
+    data.configure_for(neural_model)
+
+    opt = torch.optim.Adam(
+        list(neural_model.parameters()) + list(config.criterion.parameters()),
+        lr=3e-4,
+        weight_decay=0,
+    )
+
+    regularizer = cebra.models.jacobian_regularizer.JacobianReg()
+
+    solver = cebra.solver.init(
+        name="multiobjective-solver",
+        model=neural_model,
+        feature_ranges=config.feature_ranges,
+        regularizer=regularizer,
+        renormalize=False,
+        use_sam=False,
+        criterion=config.criterion,
+        optimizer=opt,
+        tqdm_on=False,
+    ).to(device)
+
+    # Train model with reduced steps for regularizer
+    weight_scheduler = LinearRampUp(
+        n_splits=2,
+        step_to_switch_on_reg=25,  # Reduced from 2500
+        step_to_switch_off_reg=40,  # Reduced from 15000
+        start_weight=0.,
+        end_weight=0.01,
+        stay_constant_after_switch_off=True)
+
+    solver.fit(
+        loader=loader,
+        valid_loader=None,
+        log_frequency=None,
+        scheduler_regularizer=weight_scheduler,
+        scheduler_loss=None,
+    )
+
+    # Basic test that model runs and produces output
+    solver.model.split_outputs = False
+    embedding = solver.model(data.neural.to(device)).detach().cpu()
+
+    # Verify output dimensions
+    assert embedding.shape[1] == n_latents, "Incorrect embedding dimension"
+    assert not torch.isnan(embedding).any(), "NaN values in embedding"
+
+    # Test attribution map functionality
+    data.neural.requires_grad_(True)
+    method = cebra.attribution.init(name="jacobian-based",
+                                    model=solver.model,
+                                    input_data=data.neural,
+                                    output_dimension=solver.model.num_output)
+
+    result = method.compute_attribution_map()
+    jfinv = abs(result['jf-inv-lsq']).mean(0)
+
+    # Verify attribution map output
+    assert not torch.isnan(
+        torch.tensor(jfinv)).any(), "NaN values in attribution map"
+    assert jfinv.shape == gt_attribution_map.shape, "Incorrect attribution map shape"
+
+    # Test split outputs functionality
+    solver.model.split_outputs = True
+    embedding_split = solver.model(data.neural.to(device))
+    Z1_hat = embedding_split[0].detach().cpu()
+    Z2_hat = embedding_split[1].detach().cpu()
+
+    # TODO(stes): Right now, this results 6D output vs. 3D as expected. Need to double check
+    # the API docs on the desired behavior here, both could be fine...
+    # assert Z1_hat.shape == Z1.shape, f"Incorrect Z1 embedding dimension: {Z1_hat.shape}"
+    assert Z2_hat.shape == Z2.shape, f"Incorrect Z2 embedding dimension: {Z2_hat.shape}"
+    assert not torch.isnan(Z1_hat).any(), "NaN values in Z1 embedding"
+    assert not torch.isnan(Z2_hat).any(), "NaN values in Z2 embedding"
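The LinearRampUp scheduler above gates the Jacobian regularizer: with the arguments used in this test, the weight stays at start_weight until step 25, ramps up to end_weight by step 40, and then stays constant. The sketch below only illustrates the schedule those arguments suggest; it is an assumed simplification, not the implementation in cebra.solver.schedulers:

```python
# Rough sketch of the assumed ramp-up behavior (not the CEBRA implementation).
def ramp_weight(step,
                step_to_switch_on_reg=25,
                step_to_switch_off_reg=40,
                start_weight=0.0,
                end_weight=0.01,
                stay_constant_after_switch_off=True):
    if step < step_to_switch_on_reg:
        return start_weight
    if step >= step_to_switch_off_reg:
        return end_weight if stay_constant_after_switch_off else start_weight
    # linear interpolation between the switch-on and switch-off steps
    fraction = (step - step_to_switch_on_reg) / (
        step_to_switch_off_reg - step_to_switch_on_reg)
    return start_weight + fraction * (end_weight - start_weight)


print([round(ramp_weight(step), 4) for step in (0, 25, 30, 40, 50)])
# [0.0, 0.0, 0.0033, 0.01, 0.01]
```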
diff --git a/tests/test_models.py b/tests/test_models.py
index d41dc7ab..658cc467 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -90,6 +90,10 @@ def test_offset_models(model_name, batch_size, input_length):
 
 
 def test_multiobjective():
+    # NOTE(stes): This test is deprecated and will be removed in a future version.
+    # As of CEBRA 0.6.0, the multi objective models are tested separately in
+    # test_multiobjective.py.
+
     class TestModel(cebra.models.Model):
 
         def __init__(self):
diff --git a/tests/test_multiobjective.py b/tests/test_multiobjective.py
new file mode 100644
index 00000000..a4c601ac
--- /dev/null
+++ b/tests/test_multiobjective.py
@@ -0,0 +1,145 @@
+#
+# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables
+# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+)
+# Source code:
+# https://github.com/AdaptiveMotorControlLab/CEBRA
+#
+# Please see LICENSE.md for the full license document:
+# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import warnings
+
+import pytest
+import torch
+
+import cebra
+from cebra.data import ContrastiveMultiObjectiveLoader
+from cebra.data import DatasetxCEBRA
+from cebra.solver import MultiObjectiveConfig
+
+
+@pytest.fixture
+def config():
+    neurons = torch.randn(100, 5)
+    behavior1 = torch.randn(100, 2)
+    behavior2 = torch.randn(100, 1)
+    data = DatasetxCEBRA(neurons, behavior1=behavior1, behavior2=behavior2)
+    loader = ContrastiveMultiObjectiveLoader(dataset=data,
+                                             num_steps=1,
+                                             batch_size=24)
+    return MultiObjectiveConfig(loader)
+
+
+def test_imports():
+    pass
+
+
+def test_add_data(config):
+    config.set_slice(0, 10)
+    config.set_loss('loss_name', param1='value1')
+    config.set_distribution('distribution_name', param2='value2')
+    config.push()
+
+    assert len(config.total_info) == 1
+    assert config.total_info[0]['slice'] == (0, 10)
+    assert config.total_info[0]['losses'] == {
+        "name": 'loss_name',
+        "kwargs": {
+            'param1': 'value1'
+        }
+    }
+    assert config.total_info[0]['distributions'] == {
+        "name": 'distribution_name',
+        "kwargs": {
+            'param2': 'value2'
+        }
+    }
+
+
+def test_overwriting_key_warning(config):
+    with warnings.catch_warnings(record=True) as w:
+        config.set_slice(0, 10)
+        config.set_slice(10, 20)
+        assert len(w) == 1
+        assert issubclass(w[-1].category, UserWarning)
+        assert "Configuration key already exists" in str(w[-1].message)
+
+
+def test_missing_slice_error(config):
+    with pytest.raises(RuntimeError, match="Slice configuration is missing"):
+        config.set_loss('loss_name', param1='value1')
+        config.set_distribution('distribution_name', param2='value2')
+        config.push()
+
+
+def test_missing_distributions_error(config):
+    with pytest.raises(RuntimeError,
+                       match="Distributions configuration is missing"):
+        config.set_slice(0, 10)
+        config.set_loss('loss_name', param1='value1')
+        config.push()
+
+
+def test_missing_losses_error(config):
+    with pytest.raises(RuntimeError, match="Losses configuration is missing"):
+        config.set_slice(0, 10)
+        config.set_distribution('distribution_name', param2='value2')
+        config.push()
+
+
+def test_finalize(config):
+    config.set_slice(0, 6)
+    config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+    config.set_distribution("time", time_offset=1)
+    config.push()
+
+    config.set_slice(3, 6)
+    config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+    config.set_distribution("time_delta", time_delta=3, label_name="behavior2")
+    config.push()
+
+    config.finalize()
+
+    assert len(config.losses) == 2
+    assert config.losses[0]['indices'] == (0, 6)
+    assert config.losses[1]['indices'] == (3, 6)
+
+    assert len(config.feature_ranges) == 2
+    assert config.feature_ranges[0] == slice(0, 6)
+    assert config.feature_ranges[1] == slice(3, 6)
+
+    assert len(config.loader.distributions) == 2
+    assert isinstance(config.loader.distributions[0],
+                      cebra.distributions.continuous.TimeContrastive)
+    assert config.loader.distributions[0].time_offset == 1
+
+    assert isinstance(config.loader.distributions[1],
+                      cebra.distributions.continuous.TimedeltaDistribution)
+    assert config.loader.distributions[1].time_delta == 3
+
+
+def test_non_unique_feature_ranges_error(config):
+    config.set_slice(0, 10)
+    config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+    config.set_distribution("time", time_offset=1)
+    config.push()
+
+    config.set_slice(0, 10)
+    config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+    config.set_distribution("time_delta", time_delta=3, label_name="behavior2")
+    config.push()
+
+    with pytest.raises(RuntimeError, match="Feature ranges are not unique"):
+        config.finalize()
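For reference, the two feature ranges asserted in test_finalize describe overlapping views into a single embedding: the first objective reads all six latent dimensions, the second only the last three. Plain tensor indexing (a sketch independent of the CEBRA solver, using a stand-in random embedding) makes that relationship explicit:

```python
# Sketch: how slice(0, 6) and slice(3, 6) carve up one 6-dimensional embedding.
import torch

embedding = torch.randn(100, 6)  # stand-in for the full model output
feature_ranges = [slice(0, 6), slice(3, 6)]

views = [embedding[:, r] for r in feature_ranges]
print(views[0].shape, views[1].shape)  # torch.Size([100, 6]) torch.Size([100, 3])
```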