From 679423b880e022d0584f7247a336f8e44f3fede4 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Sat, 15 Dec 2012 16:01:17 +0100 Subject: [PATCH] ENH : cleanup ICA code --- mne/fiff/constants.py | 10 +- mne/preprocessing/ica.py | 286 +++++++++++----------------- mne/preprocessing/tests/test_ica.py | 38 ++-- mne/viz.py | 8 +- 4 files changed, 145 insertions(+), 197 deletions(-) diff --git a/mne/fiff/constants.py b/mne/fiff/constants.py index 55103587a90..eb5282c600d 100644 --- a/mne/fiff/constants.py +++ b/mne/fiff/constants.py @@ -343,12 +343,10 @@ def __init__(self, **kwargs): FIFF.FIFF_MNE_ICA_INTERFACE_PARAMS = 3601 # ICA interface parameters FIFF.FIFF_MNE_ICA_CHANNEL_NAMES = 3602 # ICA channel names FIFF.FIFF_MNE_ICA_WHITENER = 3603 # ICA whitener -FIFF.FIFF_MNE_ICA_PCA_PARAMS = 3604 # _PCA parameters -FIFF.FIFF_MNE_ICA_PCA_COMPONENTS = 3605 # _PCA components -FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR = 3606 # _PCA explained variance -FIFF.FIFF_MNE_ICA_PCA_MEAN = 3607 # _PCA mean -FIFF.FIFF_MNE_ICA_PARAMS = 3608 # _ICA parameters -FIFF.FIFF_MNE_ICA_MATRIX = 3609 # _ICA unmixing matrix +FIFF.FIFF_MNE_ICA_PCA_COMPONENTS = 3604 # PCA components +FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR = 3605 # PCA explained variance +FIFF.FIFF_MNE_ICA_PCA_MEAN = 3606 # PCA mean +FIFF.FIFF_MNE_ICA_MATRIX = 3607 # ICA unmixing matrix # # Fiff values associated with MNE computations # diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 99532265d61..eaf9eaa1572 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -8,7 +8,6 @@ import inspect import warnings from inspect import getargspec, isfunction -from collections import namedtuple as nt import os import logging @@ -66,16 +65,6 @@ def _make_xy_sfunc(func, ndim_output=False): __all__ = ['ICA', 'ica_find_ecg_events', 'ica_find_eog_events', 'score_funcs', 'read_ica'] -PARAMS_ICA = nt('ICA_parameters', ['fun_args', 'fun_prime', 'algorithm', - 'max_iter', 'random_state', 'n_components', 'tol', - 'fun', 'w_init', 'whiten']) - -PARAMS_PCA = nt('PCA_parameters', ['random_state', 'copy', 'n_components', - 'iterated_power', 'whiten']) - -ATTRIBUTES_PCA = nt('PCA_attributes', ['components_', 'mean_', 'explained_variance_', - 'explained_variance_ratio_']) - class ICA(object): """M/EEG signal decomposition using Independent Component Analysis (ICA) @@ -120,64 +109,53 @@ class ICA(object): Attributes ---------- - last_fit : str - Flag informing about which type was last fit. + current_fit : str + Flag informing about which data type (raw or epochs) was used for + the fit. ch_names : list-like Channel names resulting from initial picking. - n_components : int - The number of components used for ICA decomposition. max_n_components : int The number of PCA dimensions computed. verbose : bool, str, int, or None See above. - mixing_matrix : None | ndarray + pca_components_ : ndarray + If fit, the PCA components + pca_mean_ : ndarray + If fit, the mean vector used to center the data before doing the PCA. + pca_explained_variance_ : ndarray + If fit, the variance explained by each PCA component + n_ica_components_ : int + The number of components used for ICA decomposition. + mixing_matrix_ : None | ndarray If fit, the mixing matrix to restore observed data, else None. - unmixing_matrix : None | ndarray + unmixing_matrix_ : None | ndarray If fit, the matrix to unmix observed data, else None. """ @verbose def __init__(self, n_components, max_n_components=100, noise_cov=None, random_state=None, algorithm='parallel', fun='logcosh', fun_args=None, verbose=None): - try: - from sklearn.decomposition import FastICA # to avoid strong dep. - except ImportError: - raise Exception('the scikit-learn package is missing and ' - 'required for ICA') self.noise_cov = noise_cov - # sklearn < 0.11 does not support random_state argument for FastICA - kwargs = {'algorithm': algorithm, 'fun': fun, 'fun_args': fun_args} - - if random_state is not None: - aspec = inspect.getargspec(FastICA.__init__) - if 'random_state' not in aspec.args: - warnings.warn('random_state argument ignored, update ' - 'scikit-learn to version 0.11 or newer') - else: - kwargs['random_state'] = random_state - if max_n_components is not None and n_components > max_n_components: raise ValueError('n_components must be smaller than ' 'max_n_components') - if isinstance(n_components, float): - if not 0 < n_components <= 1: - raise ValueError('For selecting ICA components by the ' - 'explained variance of PCA components the' - ' float value must be between 0.0 and 1.0 ') - self._explained_var = n_components - logger.info('Selecting pca_components via explained variance.') - else: - self._explained_var = 1.1 - logger.info('Selecting pca_components directly.') + if isinstance(n_components, float) \ + and not 0 < n_components <= 1: + raise ValueError('For selecting ICA components by the ' + 'explained variance of PCA components the' + ' float value must be between 0.0 and 1.0 ') - self._ica = FastICA(**kwargs) self.current_fit = 'unfitted' self.verbose = verbose self.n_components = n_components self.max_n_components = max_n_components self.ch_names = None + self.random_state = random_state + self.algorithm = algorithm + self.fun = fun + self.fun_args = fun_args def __repr__(self): s = 'ICA ' @@ -318,7 +296,7 @@ def get_sources_raw(self, raw, start=None, stop=None): sources : array, shape = (n_components, n_times) The ICA sources time series. """ - if self.mixing_matrix is None: + if not hasattr(self, 'mixing_matrix_'): raise RuntimeError('No fit available. Please first fit ICA ' 'decomposition.') @@ -328,8 +306,8 @@ def _get_sources_raw(self, raw, start, stop): picks = [raw.ch_names.index(k) for k in self.ch_names] data, _ = self._pre_whiten(raw[picks, start:stop][0], raw.info, picks) pca_data = self._transform_pca(data.T) - raw_sources = self._transform_ica(pca_data[:, self._comp_idx]).T - + n_ica_components = self.n_ica_components_ + raw_sources = self._transform_ica(pca_data[:, :n_ica_components]).T return raw_sources, pca_data def get_sources_epochs(self, epochs, concatenate=False): @@ -347,7 +325,7 @@ def get_sources_epochs(self, epochs, concatenate=False): epochs_sources : ndarray of shape (n_epochs, n_sources, n_times) The sources for each epoch """ - if self.mixing_matrix is None: + if not hasattr(self, 'mixing_matrix_'): raise RuntimeError('No fit available. Please first fit ICA ' 'decomposition.') @@ -370,7 +348,7 @@ def _get_sources_epochs(self, epochs, concatenate): epochs.info, picks) pca_data = self._transform_pca(data.T) - sources = self._transform_ica(pca_data[:, self._comp_idx]).T + sources = self._transform_ica(pca_data[:, :self.n_ica_components_]).T sources = np.array(np.split(sources, len(epochs.events), 1)) if concatenate: @@ -386,7 +364,6 @@ def save(self, fname): ---------- fname : str The absolute path of the file name to save the ICA session into. - """ if self.current_fit == 'unfitted': raise RuntimeError('No fit available. Please first fit ICA ' @@ -448,7 +425,7 @@ def export_sources(self, raw, picks=None, start=None, stop=None): # set channel names and info ch_names = out.info['ch_names'] = [] ch_info = out.info['chs'] = [] - for i in xrange(self.n_components): + for i in xrange(self.n_ica_components_): ch_names.append('ICA %03d' % (i + 1)) ch_info.append(dict(ch_name='ICA %03d' % (i + 1), cal=1, logno=i + 1, coil_type=FIFF.FIFFV_COIL_NONE, @@ -464,7 +441,7 @@ def export_sources(self, raw, picks=None, start=None, stop=None): ch_info += [raw.info['chs'][k] for k in picks] # update number of channels - out.info['nchan'] = len(picks) + self.n_components + out.info['nchan'] = len(picks) + self.n_ica_components_ return out @@ -500,7 +477,6 @@ def plot_sources_raw(self, raw, order=None, start=None, stop=None, ------- fig : instance of pyplot.Figure """ - sources = self.get_sources_raw(raw, start=start, stop=stop) if order is not None: @@ -569,10 +545,7 @@ def plot_sources_epochs(self, epochs, epoch_idx=None, order=None, fig = plot_ica_panel(sources[epoch_idx], start=start, stop=stop, n_components=n_components, source_idx=source_idx, - ncol=ncol, nrow=nrow) - if show: - import matplotlib.pylab as pl - pl.show() + ncol=ncol, nrow=nrow, show=show) return fig @@ -775,24 +748,6 @@ def pick_sources_epochs(self, epochs, include=None, exclude=None, return epochs - @property - def mixing_matrix(self): - """The ICA mixing matrix""" - if hasattr(self, '_unmixing'): - out = linalg.pinv(self._unmixing).T - else: - out = None - return out - - @property - def unmixing_matrix(self): - """The ICA mixing matrix""" - if hasattr(self, '_unmixing'): - out = self._unmixing - else: - out = None - return out - def _pre_whiten(self, data, info, picks): """Helper function""" if self.noise_cov is None: # use standardization as whitener @@ -812,7 +767,7 @@ def _pre_whiten(self, data, info, picks): return data, pre_whitener - def _decompose(self, data, max_n_components, caller): + def _decompose(self, data, max_n_components, fit_type): """ Helper Function """ from sklearn.decomposition import RandomizedPCA @@ -830,38 +785,56 @@ def _decompose(self, data, max_n_components, caller): pca = RandomizedPCA(**kwargs) pca_data = pca.fit_transform(data.T) - if self._explained_var > 1.0: + if isinstance(self.n_components, float): + logger.info('Selecting pca_components via explained variance.') + n_ica_components_ = np.sum(pca.explained_variance_ratio_.cumsum() + < self.n_components) + to_ica = pca_data[:, :n_ica_components_] + else: + logger.info('Selecting pca_components directly.') if self.n_components is not None: # normal n case - self._comp_idx = np.arange(self.n_components) - to_ica = pca_data[:, self._comp_idx] + to_ica = pca_data[:, :self.n_components] else: # None case to_ica = pca_data self.n_components = pca_data.shape[1] - self._comp_idx = np.arange(self.n_components) - else: # float case - expl_var = pca.explained_variance_ratio_ - self._comp_idx = (np.where(expl_var.cumsum() < - self._explained_var)[0]) - to_ica = pca_data[:, self._comp_idx] - self.n_components = len(self._comp_idx) - self._ica.fit(to_ica) - - if not hasattr(self._ica, 'sources_'): - self._unmixing = self._ica.unmixing_matrix_ - else: - self._unmixing = self._ica.components_ - self.current_fit = caller + # the things to store for PCA + self.pca_components_ = pca.components_ + self.pca_mean_ = pca.mean_ + self.pca_explained_variance_ = pca.explained_variance_ + # and store number of components as it may be smaller than + # pca.components_.shape[1] + self.n_ica_components_ = to_ica.shape[1] + + # Take care of ICA + try: + from sklearn.decomposition import FastICA # to avoid strong dep. + except ImportError: + raise Exception('the scikit-learn package is missing and ' + 'required for ICA') + + # sklearn < 0.11 does not support random_state argument for FastICA + kwargs = {'algorithm': self.algorithm, 'fun': self.fun, + 'fun_args': self.fun_args} - self._pca = ATTRIBUTES_PCA(** dict((k, vars(pca)[k]) for k in - ATTRIBUTES_PCA._fields)) + if self.random_state is not None: + aspec = inspect.getargspec(FastICA.__init__) + if 'random_state' not in aspec.args: + warnings.warn('random_state argument ignored, update ' + 'scikit-learn to version 0.11 or newer') + else: + kwargs['random_state'] = self.random_state + ica = FastICA(**kwargs) + ica.fit(to_ica) - self._params_pca = PARAMS_PCA(** dict((k, vars(pca).get(k, 'NA')) - for k in PARAMS_PCA._fields)) + # For ICA the only thing to store is the unmixing matrix + if not hasattr(ica, 'sources_'): + self.unmixing_matrix_ = ica.unmixing_matrix_ + else: + self.unmixing_matrix_ = ica.components_ - self._params_ica = PARAMS_ICA(** dict((k, vars(self._pca).get(k, 'NA')) - for k in PARAMS_ICA._fields)) - del self._ica + self.mixing_matrix_ = linalg.pinv(self.unmixing_matrix_).T + self.current_fit = fit_type def _pick_sources(self, sources, pca_data, include, exclude, n_pca_components): @@ -877,24 +850,21 @@ def _pick_sources(self, sources, pca_data, include, exclude, sources[exclude, :] = 0. # just exclude # restore pca data - pca_restored = np.dot(sources.T, self.mixing_matrix) + pca_restored = np.dot(sources.T, self.mixing_matrix_) # re-append deselected pca dimension if desired - if n_pca_components - self.n_components > 0: - add_components = np.arange(self.n_components, n_pca_components) - pca_reappend = pca_data[:, add_components] + if n_pca_components > self.n_ica_components_: + pca_reappend = pca_data[:, self.n_ica_components_:n_pca_components] pca_restored = np.c_[pca_restored, pca_reappend] # restore sensor space data - out = self._inverse_t_pca(pca_restored) + out = self._inverse_transform_pca(pca_restored) # restore scaling - pre_whitener = self._pre_whitener.copy() if self.noise_cov is None: # revert standardization - pre_whitener **= -1 - out *= pre_whitener + out /= self._pre_whitener else: - out = np.dot(out, linalg.pinv(pre_whitener)) + out = np.dot(out, linalg.pinv(self._pre_whitener)) return out.T @@ -902,24 +872,24 @@ def _transform_pca(self, data): """Apply decorrelation / dimensionality reduction on MEEG data. """ X = np.atleast_2d(data) - if self._pca.mean_ is not None: - X = X - self._pca.mean_ + if self.pca_mean_ is not None: + X = X - self.pca_mean_ - X = np.dot(X, self._pca.components_.T) + X = np.dot(X, self.pca_components_.T) return X def _transform_ica(self, data): """Apply ICA un-mixing matrix to recover the latent sources. """ - return np.dot(np.atleast_2d(data), self.unmixing_matrix.T) + return np.dot(np.atleast_2d(data), self.unmixing_matrix_.T) - def _inverse_t_pca(self, X): + def _inverse_transform_pca(self, X): """Helper Function""" - components = self._pca.components_[np.arange(len(X.T))] + components = self.pca_components_[:X.shape[1]] X_orig = np.dot(X, components) - if self._pca.mean_ is not None: - X_orig += self._pca.mean_ + if self.pca_mean_ is not None: + X_orig += self.pca_mean_ return X_orig @@ -1080,22 +1050,13 @@ def _write_ica(fid, ica): ica: The instance of ICA to write """ - - _params_ica = ica._params_ica._asdict() - for key in ('fun_args', 'fun_prime'): - v = _params_ica.get(key, None) - if v is not None: - _params_ica[key] = (_serialize(v, '#') if isinstance(v, dict) - else str(v)) - else: - _params_ica[key] = str(None) - ica_interface = dict(noise_cov=ica.noise_cov, max_n_components=ica.max_n_components, n_components=ica.n_components, current_fit=ica.current_fit, - _explained_var=ica._explained_var - ) + algorithm=ica.algorithm, + fun=ica.fun, + fun_args=ica.fun_args) start_block(fid, FIFF.FIFFB_ICA) @@ -1110,25 +1071,19 @@ def _write_ica(fid, ica): # Whitener write_double_matrix(fid, FIFF.FIFF_MNE_ICA_WHITENER, ica._pre_whitener) - # _PCA parameters - write_string(fid, FIFF.FIFF_MNE_ICA_PCA_PARAMS, - _serialize(ica._params_pca._asdict())) - - # _PCA components_ + # PCA components_ write_double_matrix(fid, FIFF.FIFF_MNE_ICA_PCA_COMPONENTS, - ica._pca.components_) + ica.pca_components_) - # _PCA explained_variance_ - write_double_matrix(fid, FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR, - ica._pca.explained_variance_) - # _PCA mean_ - write_double_matrix(fid, FIFF.FIFF_MNE_ICA_PCA_MEAN, ica._pca.mean_) + # PCA mean_ + write_double_matrix(fid, FIFF.FIFF_MNE_ICA_PCA_MEAN, ica.pca_mean_) - # _ICA parameters - write_string(fid, FIFF.FIFF_MNE_ICA_PARAMS, _serialize(_params_ica)) + # PCA explained_variance_ + write_double_matrix(fid, FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR, + ica.pca_explained_variance_) - # _ICA unmixing - write_double_matrix(fid, FIFF.FIFF_MNE_ICA_MATRIX, ica.unmixing_matrix) + # ICA unmixing + write_double_matrix(fid, FIFF.FIFF_MNE_ICA_MATRIX, ica.unmixing_matrix_) # Done! end_block(fid, FIFF.FIFFB_ICA) @@ -1136,16 +1091,17 @@ def _write_ica(fid, ica): @verbose def read_ica(fname): - """ Restore ICA sessions from fif file. + """Restore ICA sessions from fif file. Parameters ---------- fname : str - Absolute path to fif file containing ICA matrixes + Absolute path to fif file containing ICA matrices. Returns ------- ica : instance of ICA + The ICA estimator. """ try: from sklearn.decomposition import FastICA, RandomizedPCA @@ -1156,6 +1112,7 @@ def read_ica(fname): logger.info('Reading %s ...' % fname) fid, tree, _ = fiff_open(fname) ica_data = dir_tree_find(tree, FIFF.FIFFB_ICA) + if len(ica_data) == 0: fid.close() raise ValueError('Could not find ICA data') @@ -1172,22 +1129,16 @@ def read_ica(fname): ch_names = tag.data elif kind == FIFF.FIFF_MNE_ICA_WHITENER: tag = read_tag(fid, pos) - _pre_whitener = tag.data - elif kind == FIFF.FIFF_MNE_ICA_PCA_PARAMS: - tag = read_tag(fid, pos) - _params_pca = tag.data + pre_whitener = tag.data elif kind == FIFF.FIFF_MNE_ICA_PCA_COMPONENTS: tag = read_tag(fid, pos) - components_ = tag.data + pca_components = tag.data elif kind == FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR: tag = read_tag(fid, pos) - explained_variance_ = tag.data + pca_explained_variance = tag.data elif kind == FIFF.FIFF_MNE_ICA_PCA_MEAN: tag = read_tag(fid, pos) - mean_ = tag.data - elif kind == FIFF.FIFF_MNE_ICA_PARAMS: - tag = read_tag(fid, pos) - _params_ica = tag.data + pca_mean = tag.data elif kind == FIFF.FIFF_MNE_ICA_MATRIX: tag = read_tag(fid, pos) unmixing_matrix = tag.data @@ -1196,7 +1147,6 @@ def read_ica(fname): interface = _deserialize(ica_interface) current_fit = interface.pop('current_fit') - _explained_var = interface.pop('_explained_var') if interface['noise_cov'] == Covariance.__name__: logger.warning('The noise covariance used on fit cannot be restored.' 'The whitener drawn from the covariance will be used.') @@ -1206,21 +1156,13 @@ def read_ica(fname): ica = ICA(**interface) ica.current_fit = current_fit ica.ch_names = ch_names.split(':') - ica._comp_idx = np.arange(ica.n_components) - ica._pre_whitener = _pre_whitener - ica._explained_var = _explained_var - ica._unmixing = unmixing_matrix - - ica._params_ica = (PARAMS_ICA(** _deserialize(_params_ica)) if _params_ica - else None) - ica._params_pca = (PARAMS_PCA(** _deserialize(_params_pca)) if _params_pca - else None) - - ica._pca = ATTRIBUTES_PCA(components_=components_, - mean_=mean_, - explained_variance_=explained_variance_, - explained_variance_ratio_=explained_variance_ / \ - explained_variance_.sum()) + ica._pre_whitener = pre_whitener + ica.pca_mean_ = pca_mean + ica.pca_components_ = pca_components + ica.n_ica_components_ = unmixing_matrix.shape[0] + ica.pca_explained_variance_ = pca_explained_variance + ica.unmixing_matrix_ = unmixing_matrix + ica.mixing_matrix_ = linalg.pinv(ica.unmixing_matrix_).T logger.info('Ready.') return ica diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index b8d08878d75..ae8b61e303a 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -1,4 +1,5 @@ # Author: Denis Engemann +# Alexandre Gramfort # # License: BSD (3-clause) @@ -73,9 +74,9 @@ def test_ica_core(): max_n_components = [4] picks_ = [picks] iter_ica_params = product(noise_cov, n_components, max_n_components, - picks_) + picks_) - # test init catchers + # # test init catchers assert_raises(ValueError, ICA, n_components=3, max_n_components=2) assert_raises(ValueError, ICA, n_components=1.3, max_n_components=2) @@ -86,17 +87,19 @@ def test_ica_core(): random_state=0) print ica # to test repr + # test fit checker assert_raises(RuntimeError, ica.get_sources_raw, raw) assert_raises(RuntimeError, ica.get_sources_epochs, epochs) # test decomposition ica.decompose_raw(raw, picks=pcks, start=start, stop=stop) + print ica # to test repr # test re-init exception assert_raises(RuntimeError, ica.decompose_raw, raw, picks=picks) sources = ica.get_sources_raw(raw) - assert_true(sources.shape[0] == ica.n_components) + assert_true(sources.shape[0] == ica.n_ica_components_) # test preload filter raw3 = raw.copy() @@ -120,12 +123,13 @@ def test_ica_core(): random_state=0) ica.decompose_epochs(epochs, picks=picks) + print ica # to test repr # test pick block after epochs fit assert_raises(ValueError, ica.pick_sources_raw, raw, n_pca_components=ica.n_components) sources = ica.get_sources_epochs(epochs) - assert_true(sources.shape[1] == ica.n_components) + assert_true(sources.shape[1] == ica.n_ica_components_) assert_raises(ValueError, ica.find_sources_epochs, epochs, target=np.arange(1)) @@ -134,13 +138,13 @@ def test_ica_core(): epochs3 = epochs.copy() epochs3.preload = False assert_raises(ValueError, ica.pick_sources_epochs, epochs3, - include=[1, 2], n_pca_components=ica.n_components) + include=[1, 2], n_pca_components=ica.n_ica_components_) # test source picking for excl, incl in (([], []), ([], [1, 2]), ([1, 2], [])): epochs2 = ica.pick_sources_epochs(epochs, exclude=excl, - include=incl, copy=True, - n_pca_components=ica.n_components) + include=incl, copy=True, + n_pca_components=ica.n_ica_components_) assert_array_almost_equal(epochs2.get_data(), epochs.get_data()) @@ -177,17 +181,15 @@ def test_ica_additional(): assert_true(ica.ch_names == ica_read.ch_names) - assert_array_equal(ica.mixing_matrix, ica_read.mixing_matrix) - assert_array_equal(ica._pca.components_, - ica_read._pca.components_) - assert_array_equal(ica._pca.mean_, - ica_read._pca.mean_) - assert_array_equal(ica._pca.explained_variance_, - ica_read._pca.explained_variance_) - assert_array_equal(ica._pre_whitener, - ica_read._pre_whitener) - - assert_raises(RuntimeError, ica_read.decompose_raw, raw) + assert_array_equal(ica.mixing_matrix_, ica_read.mixing_matrix_) + assert_array_equal(ica.pca_components_, + ica_read.pca_components_) + assert_array_equal(ica.pca_mean_, ica_read.pca_mean_) + assert_array_equal(ica.pca_explained_variance_, + ica_read.pca_explained_variance_) + assert_array_equal(ica._pre_whitener, ica_read._pre_whitener) + + # assert_raises(RuntimeError, ica_read.decompose_raw, raw) sources = ica.get_sources_raw(raw) sources2 = ica_read.get_sources_raw(raw) assert_array_almost_equal(sources, sources2) diff --git a/mne/viz.py b/mne/viz.py index 18f606dc583..be4eb74842c 100644 --- a/mne/viz.py +++ b/mne/viz.py @@ -981,7 +981,8 @@ def _plot_ica_panel_onpick(event, sources=None, ylims=None): @verbose def plot_ica_panel(sources, start=None, stop=None, n_components=None, - source_idx=None, ncol=3, nrow=10, verbose=None): + source_idx=None, ncol=3, nrow=10, verbose=None, + show=True): """Create panel plots of ICA sources Clicking on the plot of an individual source opens a new figure showing @@ -1005,6 +1006,8 @@ def plot_ica_panel(sources, start=None, stop=None, n_components=None, Number of panel-rows. verbose : bool, str, int, or None If not None, override default verbose level (see mne.verbose). + show : bool + If True, plot will be shown, else just the figure is returned. Returns ------- @@ -1071,6 +1074,9 @@ def plot_ica_panel(sources, start=None, stop=None, n_components=None, callback = partial(_plot_ica_panel_onpick, sources=sources, ylims=ylims) fig.canvas.mpl_connect('pick_event', callback) + if show: + pl.show() + return fig