From 3bb43cbef38cfd430f295e7d33ffcdd26fa652f8 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 24 Jun 2023 13:41:19 +0200 Subject: [PATCH 1/3] Expose n_bins argument from align_embeddings --- cebra/integrations/sklearn/metrics.py | 31 ++++++++++++++++++++------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index 4fbc871a..c002665e 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -110,7 +110,7 @@ def _consistency_scores( Args: embeddings: List of embedding matrices. dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be - associated to the same dataset. + associated to the same dataset. Returns: List of the consistencies for each embeddings pair (first element) and @@ -145,6 +145,7 @@ def _consistency_datasets( embeddings: List[Union[npt.NDArray, torch.Tensor]], dataset_ids: Optional[List[Union[int, str, float]]], labels: List[Union[npt.NDArray, torch.Tensor]], + num_discretization_bins: int = 100 ) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]: """Compute consistency between embeddings from different datasets. @@ -158,9 +159,13 @@ def _consistency_datasets( Args: embeddings: List of embedding matrices. dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be - associated to the same dataset. + associated to the same dataset. labels: List of labels corresponding to each embedding and to use for alignment between them. + num_discretization_bins: Number of values for the digitalized common labels. The discretized labels are used + for embedding alignment. Also see the ``n_bins`` argument in :py:mod:`~..helpers.align_embeddings` + for more information on how this parameter is used internally. This argument is only used if ``labels`` + is not ``None`` and the given labels are continuous and not already discrete. 
Returns: A list of scores obtained between embeddings from different datasets (first element), @@ -203,7 +208,7 @@ def _consistency_datasets( # NOTE(celia): with default values normalized=True and n_bins = 100 aligned_embeddings = cebra_sklearn_helpers.align_embeddings( - embeddings, labels) + embeddings, labels, n_bins=num_discretization_bins) scores, pairs = _consistency_scores(aligned_embeddings, datasets=dataset_ids) between_dataset = [p[0] != p[1] for p in pairs] @@ -303,6 +308,7 @@ def consistency_score( between: Optional[Literal["datasets", "runs"]] = None, labels: Optional[List[Union[npt.NDArray, torch.Tensor]]] = None, dataset_ids: Optional[List[Union[int, str, float]]] = None, + num_discretization_bins: int = 100 ) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]: """Compute the consistency score between embeddings, either between runs or between datasets. @@ -320,6 +326,11 @@ def consistency_score( *Consistency between runs* means the consistency between embeddings obtained from multiple models trained on the **same dataset**. *Consistency between datasets* means the consistency between embeddings obtained from models trained on **different datasets**, such as different animals, sessions, etc. + num_discretization_bins: Number of values for the digitalized common labels. The discretized labels are used + for embedding alignment. Also see the ``n_bins`` argument in :py:mod:`~..helpers.align_embeddings` + for more information on how this parameter is used internally. This argument is only used if ``labels`` + is not ``None``, alignment between datasets is used (``between = "datasets"``), and the given labels + are continuous and not already discrete. 
Returns: The list of scores computed between the embeddings (first returns), the list of pairs corresponding @@ -356,12 +367,16 @@ def consistency_score( if labels is not None: raise ValueError( f"No labels should be provided for between-runs consistency.") - scores, pairs, datasets = _consistency_runs(embeddings=embeddings, - dataset_ids=dataset_ids) + scores, pairs, datasets = _consistency_runs( + embeddings=embeddings, + dataset_ids=dataset_ids, + ) elif between == "datasets": - scores, pairs, datasets = _consistency_datasets(embeddings=embeddings, - dataset_ids=dataset_ids, - labels=labels) + scores, pairs, datasets = _consistency_datasets( + embeddings=embeddings, + dataset_ids=dataset_ids, + labels=labels, + num_discretization_bins=num_discretization_bins) else: raise NotImplementedError( f"Invalid comparison, got between={between}, expects either datasets or runs." From 0faae2b17845b241a044860de818e859feb3b0c2 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 24 Jun 2023 13:51:27 +0200 Subject: [PATCH 2/3] Fix docs --- cebra/integrations/sklearn/metrics.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index c002665e..cc07e8df 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -163,8 +163,9 @@ def _consistency_datasets( labels: List of labels corresponding to each embedding and to use for alignment between them. num_discretization_bins: Number of values for the digitalized common labels. The discretized labels are used - for embedding alignment. Also see the ``n_bins`` argument in :py:mod:`~..helpers.align_embeddings` - for more information on how this parameter is used internally. This argument is only used if ``labels`` + for embedding alignment. 
Also see the ``n_bins`` argument in + :py:func:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this + parameter is used internally. This argument is only used if ``labels`` is not ``None`` and the given labels are continuous and not already discrete. Returns: A list of scores obtained between embeddings from different datasets (first element), @@ -327,8 +328,9 @@ def consistency_score( trained on the **same dataset**. *Consistency between datasets* means the consistency between embeddings obtained from models trained on **different datasets**, such as different animals, sessions, etc. num_discretization_bins: Number of values for the digitalized common labels. The discretized labels are used - for embedding alignment. Also see the ``n_bins`` argument in + :py:func:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this + parameter is used internally. This argument is only used if ``labels`` is not ``None``, alignment between datasets is used (``between = "datasets"``), and the given labels are continuous and not already discrete. From d24daeed9af11c1ff07e567bcc8fc80ce7ce344c Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Wed, 12 Jul 2023 04:36:57 +0200 Subject: [PATCH 3/3] Add sklearn helper functions to public docs --- docs/source/api.rst | 1 + docs/source/api/sklearn/helpers.rst | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 docs/source/api/sklearn/helpers.rst diff --git a/docs/source/api.rst b/docs/source/api.rst index 83a1554d..0b2945c0 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -25,6 +25,7 @@ these components in other contexts and research code bases. 
api/sklearn/cebra api/sklearn/metrics api/sklearn/decoder + api/sklearn/helpers diff --git a/docs/source/api/sklearn/helpers.rst b/docs/source/api/sklearn/helpers.rst new file mode 100644 index 00000000..0fbbd796 --- /dev/null +++ b/docs/source/api/sklearn/helpers.rst @@ -0,0 +1,7 @@ +Helper functions +---------------- + +.. automodule:: cebra.integrations.sklearn.helpers + :show-inheritance: + :members: +