From 9188fb189aedec1c7ebf8bb87804bc007273e97b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Mon, 12 Sep 2022 18:54:22 +0200 Subject: [PATCH 01/15] Add ICA.get_explained_variance_ratio() --- doc/changes/latest.inc | 2 + mne/html_templates/repr/ica.html.jinja | 4 - mne/preprocessing/ica.py | 109 ++++++++++++++++-- mne/preprocessing/tests/test_ica.py | 33 +++++- .../40_artifact_correction_ica.py | 14 +++ 5 files changed, 146 insertions(+), 16 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index d4442f9959d..1cbd1384bfc 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -43,6 +43,7 @@ Enhancements - Add :func:`mne.chpi.get_active_chpi` to retrieve the number of active hpi coils for each time point (:gh:`11122` by `Eduard Ort`_) - Add example of how to obtain time-frequency decomposition using narrow bandpass Hilbert transforms to :ref:`ex-tfr-comparison` (:gh:`11116` by `Alex Rockhill`_) - Add ``==`` and ``!=`` comparison between `mne.Projection` objects (:gh:`11147` by `Mathieu Scheltienne`_) +- :class:`mne.preprocessing.ICA` gained a new method, :meth:`~mne.preprocessing.ICA.get_explained_variance_ratio`, that allows the retrieval of the proportion of variance explained by ICA components (:gh:`11141` by `Richard Höchenberger`_) Bugs ~~~~ @@ -56,6 +57,7 @@ Bugs - Fix bug in :func:`mne.viz.plot_filter` when plotting filters created using ``output='ba'`` mode with ``compensation`` turned on. (:gh:`11040` by `Marian Dovgialo`_) - Fix bug in :func:`mne.io.read_raw_bti` where EEG, EMG, and H/VEOG channels were not detected properly, and many non-ECG channels were called ECG. The logic has been improved, and any channels of unknown type are now labeled as ``misc`` (:gh:`11102` by `Eric Larson`_) - Fix bug in :func:`mne.viz.plot_topomap` when providing ``sphere="eeglab"`` (:gh:`11081` by `Mathieu Scheltienne`_) +- The string and HTML representation of :class:`mne.preprocessing.ICA` reported incorrect values for the explained variance. This information has been removed from the representations, and should instead be retrieved via the new :meth:`~mne.preprocessing.ICA.get_explained_variance_ratio` method (:gh:`11141` by `Richard Höchenberger`_) API changes ~~~~~~~~~~~ diff --git a/mne/html_templates/repr/ica.html.jinja b/mne/html_templates/repr/ica.html.jinja index 0fb0e053b4e..080c3cd5e95 100644 --- a/mne/html_templates/repr/ica.html.jinja +++ b/mne/html_templates/repr/ica.html.jinja @@ -12,10 +12,6 @@ ICA components {{ n_components }} - - Explained variance - {{ (explained_variance * 100) | round(1) }} % - Available PCA components {{ n_pca_components }} diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 16e7ab63b44..e9f7dba4d0d 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -8,11 +8,13 @@ from inspect import isfunction from collections import namedtuple +from collections.abc import Sequence from copy import deepcopy from numbers import Integral from time import time from dataclasses import dataclass from typing import Optional, List +import warnings import math import json @@ -452,7 +454,6 @@ class _InfosForRepr: fit_n_samples: Optional[int] fit_n_components: Optional[int] fit_n_pca_components: Optional[int] - fit_explained_variance: Optional[float] ch_types: List[str] excludes: List[str] @@ -470,11 +471,6 @@ class _InfosForRepr: fit_n_pca_components = getattr(self, 'pca_components_', None) if fit_n_pca_components is not None: fit_n_pca_components = len(self.pca_components_) - fit_explained_variance = getattr(self, 'pca_explained_variance_', None) - if fit_explained_variance is not None: - abs_vars = self.pca_explained_variance_ - rel_vars = abs_vars / abs_vars.sum() - fit_explained_variance = rel_vars[:fit_n_components].sum() if self.info is not None: ch_types = [c for c in _DATA_CH_TYPES_SPLIT if c in self] @@ -493,7 +489,6 @@ class _InfosForRepr: fit_n_samples=fit_n_samples, fit_n_components=fit_n_components, fit_n_pca_components=fit_n_pca_components, - fit_explained_variance=fit_explained_variance, ch_types=ch_types, excludes=excludes ) @@ -511,8 +506,6 @@ def __repr__(self): f' (fit in {infos.fit_n_iter} iterations on ' f'{infos.fit_n_samples} samples), ' f'{infos.fit_n_components} ICA components ' - f'explaining {round(infos.fit_explained_variance * 100, 1)} % ' - f'of variance ' f'({infos.fit_n_pca_components} PCA components available), ' f'channel types: {", ".join(infos.ch_types)}, ' f'{len(infos.excludes) or "no"} sources marked for exclusion' @@ -531,7 +524,6 @@ def _repr_html_(self): n_samples=infos.fit_n_samples, n_components=infos.fit_n_components, n_pca_components=infos.fit_n_pca_components, - explained_variance=infos.fit_explained_variance, ch_types=infos.ch_types, excludes=infos.excludes ) @@ -962,6 +954,101 @@ def get_components(self): return np.dot(self.mixing_matrix_[:, :self.n_components_].T, self.pca_components_[:self.n_components_]).T + def get_explained_variance_ratio( + self, inst, *, components=None + ): + """Get the proportion of data variance explained by ICA components. + + A value similar to EEGLAB's ``pvaf`` (percent variance accounted for) + will be calculated for the specified component(s). + + Parameters + ---------- + inst : mne.io.BaseRaw | mne.BaseEpochs | mne.Evoked + The uncleaned data. + components : array-like of int | int | None + The component(s) for which to do the calculation. If more than one + component is specified, explained variance will be calculated + jointly across all supplied components. If ``None`` (default), uses + all available components. + + Returns + ------- + float + The fraction of variance in ``inst`` that can be explained by the + ICA components. + + Notes + ----- + Since ICA components cannot be assumed to be aligned orthogonally, the + sum of the proportion of variance explained by all components may not + be equal to 1. In certain edge cases, the proportion of variance + explained by a component may even be negative. + + .. versionadded:: 1.1 + """ + _validate_type( + item=inst, types=(BaseRaw, BaseEpochs, Evoked), + item_name='inst' + ) + _validate_type( + item=components, types=(None, 'int-like', Sequence, np.ndarray), + item_name='components' + ) + if isinstance(components, (Sequence, np.ndarray)): + for item in components: + _validate_type( + item=item, types='int-like', + item_name='Elements of "components"' + ) + + if self.current_fit == 'unfitted': + raise ValueError('ICA must be fitted first.') + + if components is None: + components = range(self.n_components_) + + # The algorithm implemented below should be equivalent to + # https://sccn.ucsd.edu/pipermail/eeglablist/2014/009134.html + # + # Reconstruct ("back-project") the data using only the specified ICA + # components. Don't make use of potential "spare" PCA components in + # this process – we're only interested in the contribution of the ICA + # components! + kwargs = dict( + inst=inst.copy(), + include=[components], + exclude=[], + n_pca_components=0, + verbose=False, + ) + if ( + isinstance(inst, (BaseEpochs, Evoked)) and + inst.baseline is not None + ): + # Don't warn if data was baseline-corrected. + with warnings.catch_warnings(): + warnings.filterwarnings( + action='ignore', + message='The data.*was baseline-corrected', + category=RuntimeWarning + ) + inst_recon = self.apply(**kwargs) + else: + inst_recon = self.apply(**kwargs) + + data_recon = inst_recon.get_data(picks=self.ch_names) + data_orig = inst.get_data(picks=self.ch_names) + data_diff = data_orig - data_recon + + # To estimate the data variance, we first compute the variance across + # channels at each time point, and then we average these variances. + mean_var_diff = data_diff.var(axis=0).mean() + mean_var_orig = data_orig.var(axis=0).mean() + + var_explained_ratio = 1 - mean_var_diff / mean_var_orig + return var_explained_ratio + def get_sources(self, inst, add_channels=None, start=None, stop=None): """Estimate sources given the unmixing matrix. @@ -2247,6 +2334,8 @@ def _find_sources(sources, target, score_func): def _ica_explained_variance(ica, inst, normalize=False): """Check variance accounted for by each component in supplied data. + This function is only used for sorting the components. + Parameters ---------- ica : ICA diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index 75c61c14236..c9ab58cefe7 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -490,8 +490,6 @@ def test_ica_core(method, n_components, noise_cov, n_pca_components, assert 'raw data decomposition' in repr_ assert f'{ica.n_components_} ICA components' in repr_ assert 'Available PCA components' in repr_html_ - assert 'Explained variance' in repr_html_ - assert ('mag' in ica) # should now work without error # test re-fit @@ -949,6 +947,37 @@ def f(x, y): ica.fit(raw_, picks=picks, reject_by_annotation=True) +@requires_sklearn +def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): + """Test ICA.get_explained_variance_ratio().""" + _, epochs, _ = short_raw_epochs + ica = ICA(max_iter=1) + with pytest.warns(RuntimeWarning, match='were baseline-corrected'): + ica.fit(epochs) + + # int + explained_var_comp_0 = ica.get_explained_variance_ratio( + epochs, components=0 + ) + # list of int, single element + explained_var_comp_1 = ica.get_explained_variance_ratio( + epochs, components=[1] + ) + # list of int, multiple elements + explained_var_comps_01 = ica.get_explained_variance_ratio( + epochs, components=[0, 1] + ) + # None, i.e., all components + explained_var_comps_all = ica.get_explained_variance_ratio( + epochs, components=None + ) + + assert round(explained_var_comp_0, 4) == 0.0229 + assert round(explained_var_comp_1, 4) == 0.0231 + assert round(explained_var_comps_01, 4) == 0.0459 + assert explained_var_comps_all == 1 + + @requires_sklearn @pytest.mark.slowtest @pytest.mark.parametrize('method, cov', [ diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index ae04d9cf4fe..73804a191ab 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -264,6 +264,20 @@ # when creating epoched data in the :ref:`tut-overview` tutorial). # # Now we can examine the ICs to see what they captured. +# +# Using :meth:`~mne.preprocessing.ICA.get_explained_variance_ratio`, we can +# retrieve the amount fraction of variance in the original data that is +# explained by our ICA components: + +var_explained_all = ica.get_explained_variance_ratio(filt_raw) +var_explained_first = ica.get_explained_variance_ratio( + filt_raw, + components=[0] +) +print(f'Var. explained by all components: {round(100*var_explained_all)}%') +print(f'Var. explained by first component: {round(100*var_explained_first)}%') + +# %% # `~mne.preprocessing.ICA.plot_sources` will show the time series of the # ICs. Note that in our call to `~mne.preprocessing.ICA.plot_sources` we # can use the original, unfiltered `~mne.io.Raw` object. A helpful tip is that From 47f63a7221a84d6673f8f6ca238485406f5c9286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Mon, 12 Sep 2022 19:32:48 +0200 Subject: [PATCH 02/15] More tests --- mne/preprocessing/tests/test_ica.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index c9ab58cefe7..f723d3103a7 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -950,7 +950,7 @@ def f(x, y): @requires_sklearn def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): """Test ICA.get_explained_variance_ratio().""" - _, epochs, _ = short_raw_epochs + raw, epochs, _ = short_raw_epochs ica = ICA(max_iter=1) with pytest.warns(RuntimeWarning, match='were baseline-corrected'): ica.fit(epochs) @@ -977,6 +977,11 @@ def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): assert round(explained_var_comps_01, 4) == 0.0459 assert explained_var_comps_all == 1 + # Test wih Raw + ica.get_explained_variance_ratio(raw) + # Test wih Evoked + ica.get_explained_variance_ratio(epochs.average()) + @requires_sklearn @pytest.mark.slowtest From 905be6828581a54258883e45c4c748e767dae2c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Mon, 12 Sep 2022 19:34:57 +0200 Subject: [PATCH 03/15] Test without baseline correction --- mne/preprocessing/tests/test_ica.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index f723d3103a7..d7cced1f9de 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -980,7 +980,11 @@ def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): # Test wih Raw ica.get_explained_variance_ratio(raw) # Test wih Evoked - ica.get_explained_variance_ratio(epochs.average()) + evoked = epochs.average() + ica.get_explained_variance_ratio(evoked) + # Test wih Evoked without baseline correction + evoked.baseline = None + ica.get_explained_variance_ratio(evoked) @requires_sklearn From 5d96641ca7ed9286d01824f3f2c0eb4cae655758 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Mon, 12 Sep 2022 20:45:04 +0200 Subject: [PATCH 04/15] Fix typos in comments [skip azp][skip actions] --- mne/preprocessing/tests/test_ica.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index d7cced1f9de..71dbae54d3b 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -977,12 +977,12 @@ def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): assert round(explained_var_comps_01, 4) == 0.0459 assert explained_var_comps_all == 1 - # Test wih Raw + # Test Raw ica.get_explained_variance_ratio(raw) - # Test wih Evoked + # Test Evoked evoked = epochs.average() ica.get_explained_variance_ratio(evoked) - # Test wih Evoked without baseline correction + # Test Evoked without baseline correction evoked.baseline = None ica.get_explained_variance_ratio(evoked) From 08544d53eb884ab5bb0474a57b397308f9f5678b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Mon, 12 Sep 2022 20:45:44 +0200 Subject: [PATCH 05/15] Fix versionadded [skip azp][skip actions] --- mne/preprocessing/ica.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index e9f7dba4d0d..13b0947b41a 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -985,7 +985,7 @@ def get_explained_variance_ratio( be equal to 1. In certain edge cases, the proportion of variance explained by a component may even be negative. - .. versionadded:: 1.1 + .. versionadded:: 1.2 """ _validate_type( item=inst, types=(BaseRaw, BaseEpochs, Evoked), From c966a43f7bfe76cd06d6069d3db7f21b50a3e7d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Mon, 12 Sep 2022 21:54:42 +0200 Subject: [PATCH 06/15] More test coverage --- mne/preprocessing/tests/test_ica.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index 71dbae54d3b..14a16745e2c 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -952,6 +952,11 @@ def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): """Test ICA.get_explained_variance_ratio().""" raw, epochs, _ = short_raw_epochs ica = ICA(max_iter=1) + + # Unfitted ICA should raise an exception + with pytest.raises(ValueError, match='ICA must be fitted first'): + ica.get_explained_variance_ratio(epochs) + with pytest.warns(RuntimeWarning, match='were baseline-corrected'): ica.fit(epochs) From d63a5026ac50fdbf85c52942db8d302f931dd0b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Tue, 13 Sep 2022 14:01:03 +0200 Subject: [PATCH 07/15] Update docstring --- mne/preprocessing/ica.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 13b0947b41a..3314c2663f0 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -959,9 +959,6 @@ def get_explained_variance_ratio( ): """Get the proportion of data variance explained by ICA components. - A value similar to EEGLAB's ``pvaf`` (percent variance accounted for) - will be calculated for the specified component(s). - Parameters ---------- inst : mne.io.BaseRaw | mne.BaseEpochs | mne.Evoked @@ -980,9 +977,12 @@ def get_explained_variance_ratio( Notes ----- + A value similar to EEGLAB's ``pvaf`` (percent variance accounted for) + will be calculated for the specified component(s). + Since ICA components cannot be assumed to be aligned orthogonally, the sum of the proportion of variance explained by all components may not - be equal to 1. In certain edge cases, the proportion of variance + be equal to 1. In certain situations, the proportion of variance explained by a component may even be negative. .. versionadded:: 1.2 From 1aca2b2e1e0bf6244f9f79181475c16ef5405bfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Tue, 13 Sep 2022 15:06:44 +0200 Subject: [PATCH 08/15] Process channel types separately --- mne/preprocessing/ica.py | 57 +++++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 3314c2663f0..740fe91b13d 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -955,7 +955,7 @@ def get_components(self): self.pca_components_[:self.n_components_]).T def get_explained_variance_ratio( - self, inst, *, components=None + self, inst, *, components=None, ch_type=None ): """Get the proportion of data variance explained by ICA components. @@ -968,12 +968,19 @@ def get_explained_variance_ratio( component is specified, explained variance will be calculated jointly across all supplied components. If ``None`` (default), uses all available components. + ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg' | array-like of str | None + The channel type(s) to include in the calculation. If None, all + available channel types channel types will be used. Note that the + value of this parameter may change the return type (float or + dictionary). Returns ------- - float + float | dict, str -> float The fraction of variance in ``inst`` that can be explained by the - ICA components. + ICA components. If only a single ``ch_type`` was given, a float + will be returned. Otherwise, a dictionary with channel types as + keys and explained variance ratios as values. Notes ----- @@ -986,7 +993,7 @@ def get_explained_variance_ratio( explained by a component may even be negative. .. versionadded:: 1.2 - """ + """ # noqa: E501 _validate_type( item=inst, types=(BaseRaw, BaseEpochs, Evoked), item_name='inst' @@ -1005,9 +1012,47 @@ def get_explained_variance_ratio( if self.current_fit == 'unfitted': raise ValueError('ICA must be fitted first.') + _validate_type( + item=ch_type, types=(Sequence, np.ndarray, str, None), + item_name='ch_type' + ) + allowed_ch_types = ('mag', 'grad', 'planar1', 'planar2', 'eeg') + if isinstance(ch_type, str): + ch_types = [ch_type] + elif ch_type is None: + ch_types = inst.get_channel_types(unique=True, only_data_chs=True) + else: + assert isinstance(ch_type, (Sequence, np.ndarray)) + ch_types = ch_type + + assert len(ch_types) >= 1 + for ch_type in ch_types: + if ch_type not in allowed_ch_types: + raise ValueError( + f'You requested operation on the channel type ' + f'"{ch_type}", but only the following channel types are ' + f'supported: {", ".join(allowed_ch_types)}' + ) + del ch_type + if components is None: components = range(self.n_components_) + explained_var_ratios = [ + self._get_explained_variance_ratio_one_ch_type( + inst=inst, components=components, ch_type=ch_type + ) for ch_type in ch_types + ] + if len(ch_types) == 1: + result = explained_var_ratios[0] + else: + result = dict(zip(ch_types, explained_var_ratios)) + + return result + + def _get_explained_variance_ratio_one_ch_type( + self, *, inst, components, ch_type + ): # The algorithm implemented below should be equivalent to # https://sccn.ucsd.edu/pipermail/eeglablist/2014/009134.html # @@ -1037,8 +1082,8 @@ def get_explained_variance_ratio( else: inst_recon = self.apply(**kwargs) - data_recon = inst_recon.get_data(picks=self.ch_names) - data_orig = inst.get_data(picks=self.ch_names) + data_recon = inst_recon.get_data(picks=ch_type) + data_orig = inst.get_data(picks=ch_type) data_diff = data_orig - data_recon # To estimate the data variance, we first compute the variance across From bbcd19d1aedf72b79cca666e0be1c53dd83d6251 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Tue, 13 Sep 2022 15:10:11 +0200 Subject: [PATCH 09/15] Restructure --- mne/preprocessing/ica.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 740fe91b13d..efaf8c161d9 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1016,7 +1016,6 @@ def get_explained_variance_ratio( item=ch_type, types=(Sequence, np.ndarray, str, None), item_name='ch_type' ) - allowed_ch_types = ('mag', 'grad', 'planar1', 'planar2', 'eeg') if isinstance(ch_type, str): ch_types = [ch_type] elif ch_type is None: @@ -1026,6 +1025,7 @@ def get_explained_variance_ratio( ch_types = ch_type assert len(ch_types) >= 1 + allowed_ch_types = ('mag', 'grad', 'planar1', 'planar2', 'eeg') for ch_type in ch_types: if ch_type not in allowed_ch_types: raise ValueError( From b499dcf699fb04c6b02669c5d1df5c8dfcae1cde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Tue, 13 Sep 2022 15:51:19 +0200 Subject: [PATCH 10/15] Docstring --- mne/preprocessing/ica.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index efaf8c161d9..444cef1609d 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -969,14 +969,13 @@ def get_explained_variance_ratio( jointly across all supplied components. If ``None`` (default), uses all available components. ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg' | array-like of str | None - The channel type(s) to include in the calculation. If None, all - available channel types channel types will be used. Note that the - value of this parameter may change the return type (float or - dictionary). + The channel type(s) to include in the calculation. If ``None``, all + available channel types will be used. Note that the value of this + parameter may change the return type (float or dictionary). Returns ------- - float | dict, str -> float + float | dict [str, float] The fraction of variance in ``inst`` that can be explained by the ICA components. If only a single ``ch_type`` was given, a float will be returned. Otherwise, a dictionary with channel types as From eda143ac51162995d98ac13b4bc6ba790d64d620 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Tue, 13 Sep 2022 16:02:50 +0200 Subject: [PATCH 11/15] Update tutorial --- .../40_artifact_correction_ica.py | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index 73804a191ab..b6a7d72ec1b 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -213,8 +213,8 @@ filt_raw = raw.copy().filter(l_freq=1., h_freq=None) # %% -# Fitting and plotting the ICA solution -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Fitting ICA +# ~~~~~~~~~~~ # # .. admonition:: Ignoring the time domain # :class: sidebar hint @@ -262,20 +262,35 @@ # speed-up) and ``reject`` (for providing a rejection dictionary for maximum # acceptable peak-to-peak amplitudes for each channel type, just like we used # when creating epoched data in the :ref:`tut-overview` tutorial). -# + +# %% +# Looking at the ICA solution +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Now we can examine the ICs to see what they captured. # # Using :meth:`~mne.preprocessing.ICA.get_explained_variance_ratio`, we can -# retrieve the amount fraction of variance in the original data that is -# explained by our ICA components: +# retrieve the fraction of variance in the original data that is explained by +# our ICA components: + +var_explained = ica.get_explained_variance_ratio(filt_raw) +print(f'Variance explained by all components: {var_explained}') -var_explained_all = ica.get_explained_variance_ratio(filt_raw) +# %% +# The values were calculated for all ICA components jointly, but separately for +# each channel type (here: magnetometers and EEG). +# +# We can also explicitly request for which component(s) and channel type(s) to +# perform the computation: var_explained_first = ica.get_explained_variance_ratio( filt_raw, - components=[0] + components=[0], + ch_type='eeg' +) +# This time, print as percentage. +print( + f'Variance of EEG signal explained by first component: ' + f'{round(100 * var_explained_first)}%' ) -print(f'Var. explained by all components: {round(100*var_explained_all)}%') -print(f'Var. explained by first component: {round(100*var_explained_first)}%') # %% # `~mne.preprocessing.ICA.plot_sources` will show the time series of the From 5cb3cfd6059c8ac7021b9fb4c36af5009770e044 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Tue, 13 Sep 2022 16:23:42 +0200 Subject: [PATCH 12/15] Update tests --- mne/preprocessing/tests/test_ica.py | 55 ++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index 14a16745e2c..73efe65ff1b 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -960,27 +960,66 @@ def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): with pytest.warns(RuntimeWarning, match='were baseline-corrected'): ica.fit(epochs) - # int + # components = int, ch_type = None explained_var_comp_0 = ica.get_explained_variance_ratio( epochs, components=0 ) - # list of int, single element + # components = int, ch_type = str + explained_var_comp_0_eeg = ica.get_explained_variance_ratio( + epochs, components=0, ch_type='eeg' + ) + # components = int, ch_type = list of str + explained_var_comp_0_eeg_mag = ica.get_explained_variance_ratio( + epochs, components=0, ch_type=['eeg', 'mag'] + ) + # components = list of int, single element, ch_type = None explained_var_comp_1 = ica.get_explained_variance_ratio( epochs, components=[1] ) - # list of int, multiple elements + # components = list of int, multiple elements, ch_type = None explained_var_comps_01 = ica.get_explained_variance_ratio( epochs, components=[0, 1] ) - # None, i.e., all components + # components = None, i.e., all components, ch_type = None explained_var_comps_all = ica.get_explained_variance_ratio( epochs, components=None ) - assert round(explained_var_comp_0, 4) == 0.0229 - assert round(explained_var_comp_1, 4) == 0.0231 - assert round(explained_var_comps_01, 4) == 0.0459 - assert explained_var_comps_all == 1 + assert 'grad' in explained_var_comp_0 + assert 'mag' in explained_var_comp_0 + assert 'eeg' in explained_var_comp_0 + + assert isinstance(explained_var_comp_0_eeg, float) + + assert 'mag' in explained_var_comp_0_eeg_mag + assert 'eeg' in explained_var_comp_0_eeg_mag + assert 'grad' not in explained_var_comp_0_eeg_mag + + assert round(explained_var_comp_0['grad'], 4) == 0.1784 + assert round(explained_var_comp_0['mag'], 4) == 0.0259 + assert round(explained_var_comp_0['eeg'], 4) == 0.0229 + + assert np.isclose( + explained_var_comp_0['eeg'], + explained_var_comp_0_eeg + ) + assert np.isclose( + explained_var_comp_0['mag'], + explained_var_comp_0_eeg_mag['mag'] + ) + assert np.isclose( + explained_var_comp_0['eeg'], + explained_var_comp_0_eeg_mag['eeg'] + ) + + assert round(explained_var_comp_1['eeg'], 4) == 0.0231 + assert round(explained_var_comps_01['eeg'], 4) == 0.0459 + assert ( + explained_var_comps_all['grad'] == + explained_var_comps_all['mag'] == + explained_var_comps_all['eeg'] == + 1 + ) # Test Raw ica.get_explained_variance_ratio(raw) From bd2ccc293c87fcf32483958dc7555037ba3f2e74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Tue, 13 Sep 2022 21:35:51 +0200 Subject: [PATCH 13/15] Always return a dict --- mne/preprocessing/ica.py | 17 ++++++----------- mne/preprocessing/tests/test_ica.py | 5 +++-- .../preprocessing/40_artifact_correction_ica.py | 15 +++++++-------- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 444cef1609d..20b4c97f45e 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -970,16 +970,15 @@ def get_explained_variance_ratio( all available components. ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg' | array-like of str | None The channel type(s) to include in the calculation. If ``None``, all - available channel types will be used. Note that the value of this - parameter may change the return type (float or dictionary). + available channel types will be used. Returns ------- - float | dict [str, float] + dict (str, float) The fraction of variance in ``inst`` that can be explained by the - ICA components. If only a single ``ch_type`` was given, a float - will be returned. Otherwise, a dictionary with channel types as - keys and explained variance ratios as values. + ICA components, calculated separately for each channel type. + Dictionary keys are the channel types, and corresponding explained + variance ratios are the values. Notes ----- @@ -1042,11 +1041,7 @@ def get_explained_variance_ratio( inst=inst, components=components, ch_type=ch_type ) for ch_type in ch_types ] - if len(ch_types) == 1: - result = explained_var_ratios[0] - else: - result = dict(zip(ch_types, explained_var_ratios)) - + result = dict(zip(ch_types, explained_var_ratios)) return result def _get_explained_variance_ratio_one_ch_type( diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index 73efe65ff1b..6bc7b393883 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -989,7 +989,8 @@ def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): assert 'mag' in explained_var_comp_0 assert 'eeg' in explained_var_comp_0 - assert isinstance(explained_var_comp_0_eeg, float) + assert len(explained_var_comp_0_eeg) == 1 + assert 'eeg' in explained_var_comp_0_eeg assert 'mag' in explained_var_comp_0_eeg_mag assert 'eeg' in explained_var_comp_0_eeg_mag @@ -1001,7 +1002,7 @@ def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): assert np.isclose( explained_var_comp_0['eeg'], - explained_var_comp_0_eeg + explained_var_comp_0_eeg['eeg'] ) assert np.isclose( explained_var_comp_0['mag'], diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index b6a7d72ec1b..f3cf8d695de 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -270,10 +270,11 @@ # # Using :meth:`~mne.preprocessing.ICA.get_explained_variance_ratio`, we can # retrieve the fraction of variance in the original data that is explained by -# our ICA components: +# our ICA components in the form of a dictionary: -var_explained = ica.get_explained_variance_ratio(filt_raw) -print(f'Variance explained by all components: {var_explained}') +explained_var = ica.get_explained_variance_ratio(filt_raw) +for channel_type, ratio in explained_var.items(): + print(f'{channel_type} variance explained by all components: {ratio}') # %% # The values were calculated for all ICA components jointly, but separately for @@ -281,16 +282,14 @@ # # We can also explicitly request for which component(s) and channel type(s) to # perform the computation: -var_explained_first = ica.get_explained_variance_ratio( +explained_var = ica.get_explained_variance_ratio( filt_raw, components=[0], ch_type='eeg' ) # This time, print as percentage. -print( - f'Variance of EEG signal explained by first component: ' - f'{round(100 * var_explained_first)}%' -) +ratio_percent = round(100 * explained_var['eeg']) +print(f'Variance of EEG signal explained by first component: {ratio_percent}%') # %% # `~mne.preprocessing.ICA.plot_sources` will show the time series of the From 33e3a283db8f8a820add6cbbaf73f2947e913b92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Tue, 13 Sep 2022 21:49:46 +0200 Subject: [PATCH 14/15] More test coverage and better type validation messages --- mne/preprocessing/ica.py | 11 ++++++----- mne/preprocessing/tests/test_ica.py | 4 ++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 20b4c97f45e..b9c42b503c4 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -992,13 +992,16 @@ def get_explained_variance_ratio( .. versionadded:: 1.2 """ # noqa: E501 + if self.current_fit == 'unfitted': + raise ValueError('ICA must be fitted first.') + _validate_type( item=inst, types=(BaseRaw, BaseEpochs, Evoked), item_name='inst' ) _validate_type( item=components, types=(None, 'int-like', Sequence, np.ndarray), - item_name='components' + item_name='components', type_name='int, array-like of int, or None' ) if isinstance(components, (Sequence, np.ndarray)): for item in components: @@ -1007,12 +1010,9 @@ def get_explained_variance_ratio( item_name='Elements of "components"' ) - if self.current_fit == 'unfitted': - raise ValueError('ICA must be fitted first.') - _validate_type( item=ch_type, types=(Sequence, np.ndarray, str, None), - item_name='ch_type' + item_name='ch_type', type_name='str, array-like of str, or None' ) if isinstance(ch_type, str): ch_types = [ch_type] @@ -1033,6 +1033,7 @@ def get_explained_variance_ratio( ) del ch_type + # Input data validation ends here if components is None: components = range(self.n_components_) diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index 6bc7b393883..ad203756211 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -1031,6 +1031,10 @@ def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): evoked.baseline = None ica.get_explained_variance_ratio(evoked) + # Test invalid ch_type + with pytest.raises(ValueError, match='only the following channel types'): + ica.get_explained_variance_ratio(raw, ch_type='foobar') + @requires_sklearn @pytest.mark.slowtest From 0857ceed9ba0347acb843647b0c81ba0f0320887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Tue, 13 Sep 2022 22:42:27 +0200 Subject: [PATCH 15/15] Be more explicit [skip azp][skip actions] --- .../40_artifact_correction_ica.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index f3cf8d695de..8a95b0ab471 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -272,9 +272,12 @@ # retrieve the fraction of variance in the original data that is explained by # our ICA components in the form of a dictionary: -explained_var = ica.get_explained_variance_ratio(filt_raw) -for channel_type, ratio in explained_var.items(): - print(f'{channel_type} variance explained by all components: {ratio}') +explained_var_ratio = ica.get_explained_variance_ratio(filt_raw) +for channel_type, ratio in explained_var_ratio.items(): + print( + f'Fraction of {channel_type} variance explained by all components: ' + f'{ratio}' + ) # %% # The values were calculated for all ICA components jointly, but separately for @@ -282,14 +285,17 @@ # # We can also explicitly request for which component(s) and channel type(s) to # perform the computation: -explained_var = ica.get_explained_variance_ratio( +explained_var_ratio = ica.get_explained_variance_ratio( filt_raw, components=[0], ch_type='eeg' ) # This time, print as percentage. -ratio_percent = round(100 * explained_var['eeg']) -print(f'Variance of EEG signal explained by first component: {ratio_percent}%') +ratio_percent = round(100 * explained_var_ratio['eeg']) +print( + f'Fraction of variance in EEG signal explained by first component: ' + f'{ratio_percent}%' +) # %% # `~mne.preprocessing.ICA.plot_sources` will show the time series of the