Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions cebra/integrations/sklearn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _consistency_scores(
Args:
embeddings: List of embedding matrices.
dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be
associated to the same dataset.
associated to the same dataset.

Returns:
List of the consistencies for each embeddings pair (first element) and
Expand Down Expand Up @@ -145,6 +145,7 @@ def _consistency_datasets(
embeddings: List[Union[npt.NDArray, torch.Tensor]],
dataset_ids: Optional[List[Union[int, str, float]]],
labels: List[Union[npt.NDArray, torch.Tensor]],
num_discretization_bins: int = 100
) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
"""Compute consistency between embeddings from different datasets.

Expand All @@ -158,9 +159,14 @@ def _consistency_datasets(
Args:
embeddings: List of embedding matrices.
dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be
associated to the same dataset.
associated to the same dataset.
labels: List of labels corresponding to each embedding and to use for alignment
between them.
num_discretization_bins: Number of values for the digitalized common labels. The discretized labels are used
for embedding alignment. Also see the ``n_bins`` argument in
:py:mod:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
parameter is used internally. This argument is only used if ``labels``
is not ``None`` and the given labels are continuous and not already discrete.

Returns:
A list of scores obtained between embeddings from different datasets (first element),
Expand Down Expand Up @@ -203,7 +209,7 @@ def _consistency_datasets(

# NOTE(celia): with default values normalized=True and n_bins = 100
aligned_embeddings = cebra_sklearn_helpers.align_embeddings(
embeddings, labels)
embeddings, labels, n_bins=num_discretization_bins)
scores, pairs = _consistency_scores(aligned_embeddings,
datasets=dataset_ids)
between_dataset = [p[0] != p[1] for p in pairs]
Expand Down Expand Up @@ -303,6 +309,7 @@ def consistency_score(
between: Optional[Literal["datasets", "runs"]] = None,
labels: Optional[List[Union[npt.NDArray, torch.Tensor]]] = None,
dataset_ids: Optional[List[Union[int, str, float]]] = None,
num_discretization_bins: int = 100
) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
"""Compute the consistency score between embeddings, either between runs or between datasets.

Expand All @@ -320,6 +327,12 @@ def consistency_score(
*Consistency between runs* means the consistency between embeddings obtained from multiple models
trained on the **same dataset**. *Consistency between datasets* means the consistency between embeddings
obtained from models trained on **different datasets**, such as different animals, sessions, etc.
num_discretization_bins: Number of values for the digitalized common labels. The discretized labels are used
for embedding alignment. Also see the ``n_bins`` argument in
:py:mod:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
parameter is used internally. This argument is only used if ``labels``
is not ``None``, alignment between datasets is used (``between = "datasets"``), and the given labels
are continuous and not already discrete.

Returns:
The list of scores computed between the embeddings (first returns), the list of pairs corresponding
Expand Down Expand Up @@ -356,12 +369,16 @@ def consistency_score(
if labels is not None:
raise ValueError(
f"No labels should be provided for between-runs consistency.")
scores, pairs, datasets = _consistency_runs(embeddings=embeddings,
dataset_ids=dataset_ids)
scores, pairs, datasets = _consistency_runs(
embeddings=embeddings,
dataset_ids=dataset_ids,
)
elif between == "datasets":
scores, pairs, datasets = _consistency_datasets(embeddings=embeddings,
dataset_ids=dataset_ids,
labels=labels)
scores, pairs, datasets = _consistency_datasets(
embeddings=embeddings,
dataset_ids=dataset_ids,
labels=labels,
num_discretization_bins=num_discretization_bins)
else:
raise NotImplementedError(
f"Invalid comparison, got between={between}, expects either datasets or runs."
Expand Down
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ these components in other contexts and research code bases.
api/sklearn/cebra
api/sklearn/metrics
api/sklearn/decoder
api/sklearn/helpers



Expand Down
7 changes: 7 additions & 0 deletions docs/source/api/sklearn/helpers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Helper functions
----------------

.. automodule:: cebra.integrations.sklearn.helpers
:show-inheritance:
:members: