diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 8e512b6949e..0b98a03faa3 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -71,6 +71,8 @@ Enhancements - Add ``picks`` parameter to :func:`mne.preprocessing.fix_stim_artifact` to specify which channel needs to be fixed (:gh:`8482` by `Alex Gramfort`_) +- Add progress bar support to :func:`mne.time_frequency.csd_morlet` (:gh:`8608` by `Eric Larson`_) + - Further improved documentation building instructions and execution on Windows (:gh:`8502` by `kalenkovich`_ and `Eric Larson`_) - Add option to disable TQDM entirely with ``MNE_TQDM='off'`` (:gh:`8515` by `Eric Larson`_) diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index c0de757ab89..a5a2cd172d4 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -9,10 +9,11 @@ import numbers import numpy as np -from .tfr import cwt, morlet +from .tfr import _cwt_array, morlet, _get_nfft from ..fixes import rfftfreq from ..io.pick import pick_channels, _picks_to_idx -from ..utils import logger, verbose, warn, copy_function_doc_to_method_doc +from ..utils import (logger, verbose, warn, copy_function_doc_to_method_doc, + ProgressBar) from ..viz.misc import plot_csd from ..time_frequency.multitaper import (_compute_mt_params, _mt_spectra, _csd_from_mt, _psd_from_mt_adaptive) @@ -1021,9 +1022,10 @@ def csd_array_morlet(X, sfreq, frequencies, t0=0, tmin=None, tmax=None, times = times[csd_tslice] # Compute the CSD + nfft = _get_nfft(wavelets, X, use_fft) return _execute_csd_function(X, times, frequencies, _csd_morlet, - params=[sfreq, wavelets, csd_tslice, use_fft, - decim], + params=[sfreq, wavelets, nfft, csd_tslice, + use_fft, decim], n_fft=1, ch_names=ch_names, projs=projs, n_jobs=n_jobs, verbose=verbose) @@ -1140,14 +1142,8 @@ def _execute_csd_function(X, times, frequencies, csd_function, params, n_fft, # Compute CSD for each trial n_blocks = int(np.ceil(n_epochs / float(n_jobs))) - for i in range(n_blocks): + for i in ProgressBar(range(n_blocks), mesg='CSD epoch blocks'): epoch_block = X[i * n_jobs:(i + 1) * n_jobs] - if n_jobs > 1: - logger.info(' Computing CSD matrices for epochs %d..%d' - % (i * n_jobs + 1, (i + 1) * n_jobs)) - else: - logger.info(' Computing CSD matrix for epoch %d' % (i + 1)) - csds = parallel(my_csd(this_epoch, *params) for this_epoch in epoch_block) @@ -1274,7 +1270,8 @@ def _csd_multitaper(X, sfreq, n_times, window_fun, eigvals, freq_mask, n_fft, return csds -def _csd_morlet(data, sfreq, wavelets, tslice=None, use_fft=True, decim=1): +def _csd_morlet(data, sfreq, wavelets, nfft, tslice=None, use_fft=True, + decim=1): """Compute cross spectral density (CSD) using the given Morlet wavelets. Computes the CSD for a single epoch of data. @@ -1289,6 +1286,8 @@ def _csd_morlet(data, sfreq, wavelets, tslice=None, use_fft=True, decim=1): wavelets : list of ndarray The Morlet wavelets for which to compute the CSD's. These have been created by the `mne.time_frequency.tfr.morlet` function. + nfft : int + The number of FFT points. tslice : slice | None The desired time samples to compute the CSD over. If None, defaults to including all time samples. @@ -1314,7 +1313,8 @@ def _csd_morlet(data, sfreq, wavelets, tslice=None, use_fft=True, decim=1): _vector_to_sym_mat : For converting the CSD to a full matrix. """ # Compute PSD - psds = cwt(data, wavelets, use_fft=use_fft, decim=decim) + psds = _cwt_array(data, wavelets, nfft, mode='same', use_fft=use_fft, + decim=decim) if tslice is not None: tstart = None if tslice.start is None else tslice.start // decim diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 01fd66b386c..701291ffe9e 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -256,7 +256,7 @@ def test_time_frequency(): # When convolving in time, wavelets must not be longer than the data pytest.raises(ValueError, cwt, data[0, :, :Ws[0].size - 1], Ws, use_fft=False) - with pytest.warns(UserWarning, match='one of the wavelets is longer'): + with pytest.warns(UserWarning, match='one of the wavelets.*is longer'): cwt(data[0, :, :Ws[0].size - 1], Ws, use_fft=True) # Check for off-by-one errors when using wavelets with an even number of diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 74077a48a4b..873ec2ee7bf 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -20,6 +20,7 @@ from ..baseline import rescale from ..fixes import fft, ifft +from ..filter import next_fast_len from ..parallel import parallel_func from ..utils import (logger, verbose, _time_mask, _freq_mask, check_fname, sizeof_fmt, GetEpochsMixin, _prepare_read_metadata, @@ -172,7 +173,24 @@ def _make_dpss(sfreq, freqs, n_cycles=7., time_bandwidth=4.0, zero_mean=False): # Low level convolution -def _cwt(X, Ws, mode="same", decim=1, use_fft=True): +def _get_nfft(wavelets, X, use_fft=True, check=True): + n_times = X.shape[-1] + max_size = max(w.size for w in wavelets) + if max_size > n_times: + msg = (f'At least one of the wavelets ({max_size}) is longer than the ' + f'signal ({n_times}). Consider using a longer signal or ' + 'shorter wavelets.') + if check: + if use_fft: + warn(msg, UserWarning) + else: + raise ValueError(msg) + nfft = n_times + max_size - 1 + nfft = next_fast_len(nfft) # 2 ** int(np.ceil(np.log2(nfft))) + return nfft + + +def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True): """Compute cwt with fft based convolutions or temporal convolutions. Parameters @@ -181,6 +199,8 @@ def _cwt(X, Ws, mode="same", decim=1, use_fft=True): The data. Ws : list of array Wavelets time series. + fsize : int + FFT length. mode : {'full', 'valid', 'same'} See numpy.convolve. decim : int | slice, default 1 @@ -204,31 +224,15 @@ def _cwt(X, Ws, mode="same", decim=1, use_fft=True): X = np.asarray(X) # Precompute wavelets for given frequency range to save time - n_signals, n_times = X.shape + _, n_times = X.shape n_times_out = X[:, decim].shape[1] n_freqs = len(Ws) - Ws_max_size = max(W.size for W in Ws) - size = n_times + Ws_max_size - 1 - # Always use 2**n-sized FFT - fsize = 2 ** int(np.ceil(np.log2(size))) - # precompute FFTs of Ws if use_fft: fft_Ws = np.empty((n_freqs, fsize), dtype=np.complex128) - - warn_me = True - for i, W in enumerate(Ws): - if use_fft: + for i, W in enumerate(Ws): fft_Ws[i] = fft(W, fsize) - if len(W) > n_times and warn_me: - msg = ('At least one of the wavelets is longer than the signal. ' - 'Consider padding the signal or using shorter wavelets.') - if use_fft: - warn(msg, UserWarning) - warn_me = False # Suppress further warnings - else: - raise ValueError(msg) # Make generator looping across signals tfr = np.zeros((n_freqs, n_times_out), dtype=np.complex128) @@ -380,6 +384,8 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) # Parallel computation + all_Ws = sum([list(W) for W in Ws], list()) + _get_nfft(all_Ws, epoch_data, use_fft) parallel, my_cwt, _ = parallel_func(_time_frequency_loop, n_jobs) # Parallelization is applied across channels. @@ -510,7 +516,10 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): # Loops across tapers. for W in Ws: - coefs = _cwt(X, W, mode, decim=decim, use_fft=use_fft) + # No need to check here, it's done earlier (outside parallel part) + nfft = _get_nfft(W, X, use_fft, check=False) + coefs = _cwt_gen( + X, W, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft) # Inter-trial phase locking is apparently computed per taper... if 'itc' in output: @@ -586,11 +595,16 @@ def cwt(X, Ws, use_fft=True, mode='same', decim=1): mne.time_frequency.tfr_morlet : Compute time-frequency decomposition with Morlet wavelets. """ - decim = _check_decim(decim) - n_signals, n_times = X[:, decim].shape + nfft = _get_nfft(Ws, X, use_fft) + return _cwt_array(X, Ws, nfft, mode, decim, use_fft) + - coefs = _cwt(X, Ws, mode, decim=decim, use_fft=use_fft) +def _cwt_array(X, Ws, nfft, mode, decim, use_fft): + decim = _check_decim(decim) + coefs = _cwt_gen( + X, Ws, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft) + n_signals, n_times = X[:, decim].shape tfrs = np.empty((n_signals, len(Ws), n_times), dtype=np.complex128) for k, tfr in enumerate(coefs): tfrs[k] = tfr diff --git a/mne/utils/tests/test_progressbar.py b/mne/utils/tests/test_progressbar.py index c7e25282b76..c259a8b140a 100644 --- a/mne/utils/tests/test_progressbar.py +++ b/mne/utils/tests/test_progressbar.py @@ -24,7 +24,8 @@ def test_progressbar(): def iter_func(a): for ii in a: pass - pytest.raises(Exception, iter_func, ProgressBar(20)) + with pytest.raises(TypeError, match='not iterable'): + iter_func(pbar) # Make sure different progress bars can be used with catch_logging() as log, modified_env(MNE_TQDM='tqdm'), \