Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ Enhancements

- Add ``picks`` parameter to :func:`mne.preprocessing.fix_stim_artifact` to specify which channel needs to be fixed (:gh:`8482` by `Alex Gramfort`_)

- Add progress bar support to :func:`mne.time_frequency.csd_morlet` (:gh:`8608` by `Eric Larson`_)

- Further improved documentation building instructions and execution on Windows (:gh:`8502` by `kalenkovich`_ and `Eric Larson`_)

- Add option to disable TQDM entirely with ``MNE_TQDM='off'`` (:gh:`8515` by `Eric Larson`_)
Expand Down
26 changes: 13 additions & 13 deletions mne/time_frequency/csd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import numbers

import numpy as np
from .tfr import cwt, morlet
from .tfr import _cwt_array, morlet, _get_nfft
from ..fixes import rfftfreq
from ..io.pick import pick_channels, _picks_to_idx
from ..utils import logger, verbose, warn, copy_function_doc_to_method_doc
from ..utils import (logger, verbose, warn, copy_function_doc_to_method_doc,
ProgressBar)
from ..viz.misc import plot_csd
from ..time_frequency.multitaper import (_compute_mt_params, _mt_spectra,
_csd_from_mt, _psd_from_mt_adaptive)
Expand Down Expand Up @@ -1021,9 +1022,10 @@ def csd_array_morlet(X, sfreq, frequencies, t0=0, tmin=None, tmax=None,
times = times[csd_tslice]

# Compute the CSD
nfft = _get_nfft(wavelets, X, use_fft)
return _execute_csd_function(X, times, frequencies, _csd_morlet,
params=[sfreq, wavelets, csd_tslice, use_fft,
decim],
params=[sfreq, wavelets, nfft, csd_tslice,
use_fft, decim],
n_fft=1, ch_names=ch_names, projs=projs,
n_jobs=n_jobs, verbose=verbose)

Expand Down Expand Up @@ -1140,14 +1142,8 @@ def _execute_csd_function(X, times, frequencies, csd_function, params, n_fft,

# Compute CSD for each trial
n_blocks = int(np.ceil(n_epochs / float(n_jobs)))
for i in range(n_blocks):
for i in ProgressBar(range(n_blocks), mesg='CSD epoch blocks'):
epoch_block = X[i * n_jobs:(i + 1) * n_jobs]
if n_jobs > 1:
logger.info(' Computing CSD matrices for epochs %d..%d'
% (i * n_jobs + 1, (i + 1) * n_jobs))
else:
logger.info(' Computing CSD matrix for epoch %d' % (i + 1))

csds = parallel(my_csd(this_epoch, *params)
for this_epoch in epoch_block)

Expand Down Expand Up @@ -1274,7 +1270,8 @@ def _csd_multitaper(X, sfreq, n_times, window_fun, eigvals, freq_mask, n_fft,
return csds


def _csd_morlet(data, sfreq, wavelets, tslice=None, use_fft=True, decim=1):
def _csd_morlet(data, sfreq, wavelets, nfft, tslice=None, use_fft=True,
decim=1):
"""Compute cross spectral density (CSD) using the given Morlet wavelets.

Computes the CSD for a single epoch of data.
Expand All @@ -1289,6 +1286,8 @@ def _csd_morlet(data, sfreq, wavelets, tslice=None, use_fft=True, decim=1):
wavelets : list of ndarray
The Morlet wavelets for which to compute the CSD's. These have been
created by the `mne.time_frequency.tfr.morlet` function.
nfft : int
The number of FFT points.
tslice : slice | None
The desired time samples to compute the CSD over. If None, defaults to
including all time samples.
Expand All @@ -1314,7 +1313,8 @@ def _csd_morlet(data, sfreq, wavelets, tslice=None, use_fft=True, decim=1):
_vector_to_sym_mat : For converting the CSD to a full matrix.
"""
# Compute PSD
psds = cwt(data, wavelets, use_fft=use_fft, decim=decim)
psds = _cwt_array(data, wavelets, nfft, mode='same', use_fft=use_fft,
decim=decim)

if tslice is not None:
tstart = None if tslice.start is None else tslice.start // decim
Expand Down
2 changes: 1 addition & 1 deletion mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def test_time_frequency():
# When convolving in time, wavelets must not be longer than the data
pytest.raises(ValueError, cwt, data[0, :, :Ws[0].size - 1], Ws,
use_fft=False)
with pytest.warns(UserWarning, match='one of the wavelets is longer'):
with pytest.warns(UserWarning, match='one of the wavelets.*is longer'):
cwt(data[0, :, :Ws[0].size - 1], Ws, use_fft=True)

# Check for off-by-one errors when using wavelets with an even number of
Expand Down
60 changes: 37 additions & 23 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from ..baseline import rescale
from ..fixes import fft, ifft
from ..filter import next_fast_len
from ..parallel import parallel_func
from ..utils import (logger, verbose, _time_mask, _freq_mask, check_fname,
sizeof_fmt, GetEpochsMixin, _prepare_read_metadata,
Expand Down Expand Up @@ -172,7 +173,24 @@ def _make_dpss(sfreq, freqs, n_cycles=7., time_bandwidth=4.0, zero_mean=False):

# Low level convolution

def _cwt(X, Ws, mode="same", decim=1, use_fft=True):
def _get_nfft(wavelets, X, use_fft=True, check=True):
n_times = X.shape[-1]
max_size = max(w.size for w in wavelets)
if max_size > n_times:
msg = (f'At least one of the wavelets ({max_size}) is longer than the '
f'signal ({n_times}). Consider using a longer signal or '
'shorter wavelets.')
if check:
if use_fft:
warn(msg, UserWarning)
else:
raise ValueError(msg)
nfft = n_times + max_size - 1
nfft = next_fast_len(nfft) # 2 ** int(np.ceil(np.log2(nfft)))
return nfft


def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True):
"""Compute cwt with fft based convolutions or temporal convolutions.

Parameters
Expand All @@ -181,6 +199,8 @@ def _cwt(X, Ws, mode="same", decim=1, use_fft=True):
The data.
Ws : list of array
Wavelets time series.
fsize : int
FFT length.
mode : {'full', 'valid', 'same'}
See numpy.convolve.
decim : int | slice, default 1
Expand All @@ -204,31 +224,15 @@ def _cwt(X, Ws, mode="same", decim=1, use_fft=True):
X = np.asarray(X)

# Precompute wavelets for given frequency range to save time
n_signals, n_times = X.shape
_, n_times = X.shape
n_times_out = X[:, decim].shape[1]
n_freqs = len(Ws)

Ws_max_size = max(W.size for W in Ws)
size = n_times + Ws_max_size - 1
# Always use 2**n-sized FFT
fsize = 2 ** int(np.ceil(np.log2(size)))

# precompute FFTs of Ws
if use_fft:
fft_Ws = np.empty((n_freqs, fsize), dtype=np.complex128)

warn_me = True
for i, W in enumerate(Ws):
if use_fft:
for i, W in enumerate(Ws):
fft_Ws[i] = fft(W, fsize)
if len(W) > n_times and warn_me:
msg = ('At least one of the wavelets is longer than the signal. '
'Consider padding the signal or using shorter wavelets.')
if use_fft:
warn(msg, UserWarning)
warn_me = False # Suppress further warnings
else:
raise ValueError(msg)

# Make generator looping across signals
tfr = np.zeros((n_freqs, n_times_out), dtype=np.complex128)
Expand Down Expand Up @@ -380,6 +384,8 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

# Parallel computation
all_Ws = sum([list(W) for W in Ws], list())
_get_nfft(all_Ws, epoch_data, use_fft)
parallel, my_cwt, _ = parallel_func(_time_frequency_loop, n_jobs)

# Parallelization is applied across channels.
Expand Down Expand Up @@ -510,7 +516,10 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):

# Loops across tapers.
for W in Ws:
coefs = _cwt(X, W, mode, decim=decim, use_fft=use_fft)
# No need to check here, it's done earlier (outside parallel part)
nfft = _get_nfft(W, X, use_fft, check=False)
coefs = _cwt_gen(
X, W, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft)

# Inter-trial phase locking is apparently computed per taper...
if 'itc' in output:
Expand Down Expand Up @@ -586,11 +595,16 @@ def cwt(X, Ws, use_fft=True, mode='same', decim=1):
mne.time_frequency.tfr_morlet : Compute time-frequency decomposition
with Morlet wavelets.
"""
decim = _check_decim(decim)
n_signals, n_times = X[:, decim].shape
nfft = _get_nfft(Ws, X, use_fft)
return _cwt_array(X, Ws, nfft, mode, decim, use_fft)


coefs = _cwt(X, Ws, mode, decim=decim, use_fft=use_fft)
def _cwt_array(X, Ws, nfft, mode, decim, use_fft):
decim = _check_decim(decim)
coefs = _cwt_gen(
X, Ws, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft)

n_signals, n_times = X[:, decim].shape
tfrs = np.empty((n_signals, len(Ws), n_times), dtype=np.complex128)
for k, tfr in enumerate(coefs):
tfrs[k] = tfr
Expand Down
3 changes: 2 additions & 1 deletion mne/utils/tests/test_progressbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def test_progressbar():
def iter_func(a):
for ii in a:
pass
pytest.raises(Exception, iter_func, ProgressBar(20))
with pytest.raises(TypeError, match='not iterable'):
iter_func(pbar)

# Make sure different progress bars can be used
with catch_logging() as log, modified_env(MNE_TQDM='tqdm'), \
Expand Down