diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py
index 4fbc871a..cc07e8df 100644
--- a/cebra/integrations/sklearn/metrics.py
+++ b/cebra/integrations/sklearn/metrics.py
@@ -110,7 +110,7 @@ def _consistency_scores(
     Args:
         embeddings: List of embedding matrices.
         dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be
-        associated to the same dataset.
+            associated to the same dataset.
 
     Returns:
         List of the consistencies for each embeddings pair (first element) and
@@ -145,6 +145,7 @@ def _consistency_datasets(
     embeddings: List[Union[npt.NDArray, torch.Tensor]],
     dataset_ids: Optional[List[Union[int, str, float]]],
     labels: List[Union[npt.NDArray, torch.Tensor]],
+    num_discretization_bins: int = 100
 ) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
     """Compute consistency between embeddings from different datasets.
 
@@ -158,9 +159,14 @@ def _consistency_datasets(
     Args:
         embeddings: List of embedding matrices.
         dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be
-        associated to the same dataset.
+            associated to the same dataset.
         labels: List of labels corresponding to each embedding and to use for alignment
             between them.
+        num_discretization_bins: Number of bins used to discretize the common labels. The discretized labels are
+            used for embedding alignment. Also see the ``n_bins`` argument in
+            :py:func:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
+            parameter is used internally. This argument is only used if ``labels``
+            is not ``None`` and the given labels are continuous and not already discrete.
 
     Returns:
         A list of scores obtained between embeddings from different datasets (first element),
@@ -203,7 +209,7 @@ def _consistency_datasets(
 
     # NOTE(celia): with default values normalized=True and n_bins = 100
     aligned_embeddings = cebra_sklearn_helpers.align_embeddings(
-        embeddings, labels)
+        embeddings, labels, n_bins=num_discretization_bins)
     scores, pairs = _consistency_scores(aligned_embeddings,
                                         datasets=dataset_ids)
     between_dataset = [p[0] != p[1] for p in pairs]
@@ -303,6 +309,7 @@ def consistency_score(
     between: Optional[Literal["datasets", "runs"]] = None,
     labels: Optional[List[Union[npt.NDArray, torch.Tensor]]] = None,
     dataset_ids: Optional[List[Union[int, str, float]]] = None,
+    num_discretization_bins: int = 100
 ) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
     """Compute the consistency score between embeddings, either between runs or between datasets.
 
@@ -320,6 +327,12 @@ def consistency_score(
             *Consistency between runs* means the consistency between embeddings obtained from multiple models
             trained on the **same dataset**. *Consistency between datasets* means the consistency between
             embeddings obtained from models trained on **different datasets**, such as different animals, sessions, etc.
+        num_discretization_bins: Number of bins used to discretize the common labels. The discretized labels are
+            used for embedding alignment. Also see the ``n_bins`` argument in
+            :py:func:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
+            parameter is used internally. This argument is only used if ``labels``
+            is not ``None``, alignment between datasets is used (``between="datasets"``), and the given labels
+            are continuous and not already discrete.
 
     Returns:
         The list of scores computed between the embeddings (first returns), the list of pairs corresponding
@@ -356,12 +369,16 @@ def consistency_score(
         if labels is not None:
             raise ValueError(
                 f"No labels should be provided for between-runs consistency.")
-        scores, pairs, datasets = _consistency_runs(embeddings=embeddings,
-                                                    dataset_ids=dataset_ids)
+        scores, pairs, datasets = _consistency_runs(
+            embeddings=embeddings,
+            dataset_ids=dataset_ids,
+        )
     elif between == "datasets":
-        scores, pairs, datasets = _consistency_datasets(embeddings=embeddings,
-                                                        dataset_ids=dataset_ids,
-                                                        labels=labels)
+        scores, pairs, datasets = _consistency_datasets(
+            embeddings=embeddings,
+            dataset_ids=dataset_ids,
+            labels=labels,
+            num_discretization_bins=num_discretization_bins)
     else:
         raise NotImplementedError(
             f"Invalid comparison, got between={between}, expects either datasets or runs."
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 83a1554d..0b2945c0 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -25,6 +25,7 @@ these components in other contexts and research code bases.
 
    api/sklearn/cebra
    api/sklearn/metrics
    api/sklearn/decoder
+   api/sklearn/helpers
 
 
diff --git a/docs/source/api/sklearn/helpers.rst b/docs/source/api/sklearn/helpers.rst
new file mode 100644
index 00000000..0fbbd796
--- /dev/null
+++ b/docs/source/api/sklearn/helpers.rst
@@ -0,0 +1,7 @@
+Helper functions
+----------------
+
+.. automodule:: cebra.integrations.sklearn.helpers
+    :show-inheritance:
+    :members:
+
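For reference, a minimal usage sketch of the new argument. The embeddings, labels, and dataset IDs below are synthetic placeholders (not part of this PR); in practice they would come from trained CEBRA models:

```python
# Hypothetical example data; real embeddings would come from trained CEBRA models.
import numpy as np

from cebra.integrations.sklearn.metrics import consistency_score

rng = np.random.default_rng(0)

# Two embeddings from two different datasets, with continuous labels
# that are used to align the embeddings before scoring.
embeddings = [rng.normal(size=(1000, 8)), rng.normal(size=(800, 8))]
labels = [rng.uniform(0, 1, size=1000), rng.uniform(0, 1, size=800)]
dataset_ids = ["animal_a", "animal_b"]

scores, pairs, ids = consistency_score(
    embeddings=embeddings,
    between="datasets",
    labels=labels,
    dataset_ids=dataset_ids,
    # Continuous labels are discretized into this many bins before alignment;
    # 100 matches the previously hard-coded default.
    num_discretization_bins=100,
)
```

With the default `num_discretization_bins=100`, results are unchanged relative to the previous behavior, where `align_embeddings` was always called with its implicit `n_bins=100` default (see the `NOTE(celia)` comment in the diff).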