From 0d8d09c857872fd2b01a2f44a397b0d7dbcfc510 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 22 Jul 2024 21:51:27 +0200 Subject: [PATCH 01/16] Reimplement complex support --- mne/time_frequency/spectrum.py | 250 +++++++++++++------ mne/time_frequency/tests/test_spectrum.py | 280 ++++++++++++++++++++-- 2 files changed, 434 insertions(+), 96 deletions(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 45dadf9741a..7f4c2979710 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -55,7 +55,7 @@ _prepare_sensor_names, plt_show, ) -from .multitaper import psd_array_multitaper +from .multitaper import _psd_from_mt, psd_array_multitaper from .psd import _check_nfft, psd_array_welch @@ -314,13 +314,7 @@ def __init__( # method self._inst_type = type(inst) method = _validate_method(method, _get_instance_type_string(self)) - # don't allow complex output psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper) - if method_kw.get("output", "") == "complex": - raise ValueError( - f"Complex output is not supported in {type(self).__name__} objects. " - f"Please use mne.time_frequency.{psd_funcs[method].__name__}() instead." - ) # triage method and kwargs. partial() doesn't check validity of kwargs, # so we do it manually to save compute time if any are invalid. psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper) @@ -352,9 +346,12 @@ def __init__( ) if method_kw.get("average", "") in (None, False): self._dims += ("segment",) + if self._returns_complex_tapers(**method_kw): + self._dims = self._dims[:-1] + ("taper",) + self._dims[-1:] # record data type (for repr and html_repr) self._data_type = ( - "Fourier Coefficients" if "taper" in self._dims else "Power Spectrum" + f"{'Complex' if method_kw.get('output', '') == 'complex' else 'Real'} " + "Spectrum" ) # set nave (child constructor overrides this for Evoked input) self._nave = None @@ -376,6 +373,7 @@ def __getstate__(self): data_type=self._data_type, info=self.info, nave=self.nave, + mt_weights=self.mt_weights, ) return out @@ -393,6 +391,7 @@ def __setstate__(self, state): self.info = Info(**state["info"]) self._data_type = state["data_type"] self._nave = state.get("nave") # objs saved before #11282 won't have `nave` + self._mt_weights = state.get("mt_weights") # objs saved before #XXX won't have self.preload = True # instance type inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) @@ -440,14 +439,24 @@ def _check_values(self): s = _pl(bad_value.sum()) warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}', UserWarning) + def _returns_complex_tapers(self, **method_kw): + return self.method == "multitaper" and method_kw.get("output", "") == "complex" + def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): # make the spectra result = self._psd_func( data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose ) - # assign ._data ._freqs, ._shape - psds, freqs = result - self._data = psds + # assign ._data (handling unaggregated multitaper output) + if self._returns_complex_tapers(**method_kw): + fourier_coefs, freqs, weights = result + self._data = fourier_coefs + self._mt_weights = weights + else: + psds, freqs = result + self._data = psds + self._mt_weights = None + # assign properties (._data already assigned above) self._freqs = freqs # this is *expected* shape, it gets asserted later in _check_values() # (and then deleted afterwards) @@ -456,6 +465,9 @@ def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): if method_kw.get("average", "") in (None, False): n_welch_segments = _compute_n_welch_segments(data.shape[-1], method_kw) self._shape += (n_welch_segments,) + # insert n_tapers + if self._returns_complex_tapers(**method_kw): + self._shape = self._shape[:-1] + (self._mt_weights.size,) + self._shape[-1:] # we don't need these anymore, and they make save/load harder del self._picks del self._psd_func @@ -486,6 +498,10 @@ def method(self): def nave(self): return self._nave + @property + def mt_weights(self): + return self._mt_weights + @property def sfreq(self): return self._sfreq @@ -643,34 +659,13 @@ def plot( (picks_list, units_list, scalings_list, titles_list) = _split_picks_by_type( self, picks, units, scalings, titles ) - # handle unaggregated Welch - if "segment" in self._dims: - logger.info("Aggregating Welch estimates (median) before plotting...") - seg_axis = self._dims.index("segment") - _f = partial(np.nanmedian, axis=seg_axis) - else: # "normal" cases - _f = _identity_function - ch_axis = self._dims.index("channel") - psd_list = [_f(self._data.take(_p, axis=ch_axis)) for _p in picks_list] - # handle epochs - if "epoch" in self._dims: - # XXX TODO FIXME decide how to properly aggregate across repeated - # measures (epochs) and non-repeated but correlated measures - # (channels) when calculating stddev or a CI. For across-channel - # aggregation, doi:10.1007/s10162-012-0321-8 used hotellings T**2 - # with a correction factor that estimated data rank using monte - # carlo simulations; seems like we could use our own data rank - # estimation methods to similar effect. Their exact approach used - # complex spectra though, here we've already converted to power; - # not sure if that makes an important difference? Anyway that - # aggregation would need to happen in the _plot_psd function - # though, not here... for now we just average like we always did. - - # only log message if averaging will actually have an effect - if self._data.shape[0] > 1: - logger.info("Averaging across epochs...") - # epoch axis should always be the first axis - psd_list = [_p.mean(axis=0) for _p in psd_list] + # prepare data (e.g. aggregate across dims, convert complex to power) + psd_list = [ + self._prepare_data_for_plot( + self._data.take(_p, axis=self._dims.index("channel")) + ) + for _p in picks_list + ] # initialize figure fig, axes = _line_figure(self, axes, picks=picks) # don't add ylabels & titles if figure has unexpected number of axes @@ -739,8 +734,8 @@ def plot_topo( layout = find_layout(self.info) psds, freqs = self.get_data(return_freqs=True) - if "epoch" in self._dims: - psds = np.mean(psds, axis=self._dims.index("epoch")) + # prepare data (e.g. aggregate across dims, convert complex to power) + psds = self._prepare_data_for_plot(psds) if dB: psds = 10 * np.log10(psds) y_label = "dB" @@ -852,8 +847,8 @@ def plot_topomap( outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) psds, freqs = self.get_data(picks=picks, return_freqs=True) - if "epoch" in self._dims: - psds = np.mean(psds, axis=self._dims.index("epoch")) + # prepare data (e.g. aggregate across dims, convert complex to power) + psds = self._prepare_data_for_plot(psds) psds *= scaling**2 if merge_channels: @@ -891,6 +886,43 @@ def plot_topomap( show=show, ) + def _prepare_data_for_plot(self, data): + # handle unaggregated Welch + if "segment" in self._dims: + seg_axis = self._dims.index("segment") + logger.info("Aggregating Welch estimates (median) before plotting...") + data = np.nanmedian(data, axis=seg_axis) + # handle unaggregated multitaper (also handles complex -> power) + elif "taper" in self._dims: + logger.info("Aggregating multitaper estimates before plotting...") + data = _psd_from_mt(data, self.mt_weights) + + # handle complex data (should only be Welch remaining) + if np.iscomplexobj(data): + data = (data * data.conj()).real # Scaling may be slightly off + + # handle epochs + if "epoch" in self._dims: + # XXX TODO FIXME decide how to properly aggregate across repeated + # measures (epochs) and non-repeated but correlated measures + # (channels) when calculating stddev or a CI. For across-channel + # aggregation, doi:10.1007/s10162-012-0321-8 used hotellings T**2 + # with a correction factor that estimated data rank using monte + # carlo simulations; seems like we could use our own data rank + # estimation methods to similar effect. Their exact approach used + # complex spectra though, here we've already converted to power; + # not sure if that makes an important difference? Anyway that + # aggregation would need to happen in the _plot_psd function + # though, not here... for now we just average like we always did. + + # only log message if averaging will actually have an effect + if data.shape[0] > 1: + logger.info("Averaging across epochs before plotting...") + # epoch axis should always be the first axis + data = data.mean(axis=0) + + return data + @verbose def save(self, fname, *, overwrite=False, verbose=None): """Save spectrum data to disk (in HDF5 format). @@ -1062,6 +1094,9 @@ class Spectrum(BaseSpectrum): nave : int | None The number of trials averaged together when generating the spectrum. ``None`` indicates no averaging is known to have occurred. + mt_weights : array | None + The weights for each taper. Only present if spectra computed with + ``method='multitaper'`` and ``output='complex'``. See Also -------- @@ -1179,23 +1214,55 @@ def __getitem__(self, item): return BaseRaw._getitem(self, item, return_times=False) -def _check_data_shape(data, freqs, info, ndim): - if data.ndim != ndim: - raise ValueError(f"Data must be a {ndim}D array.") +def _check_data_shape( + data, info, freqs, is_epoched, events, method, standard_ndim, extra_ndim, mt_weights +): + _check_option("data.ndim", data.ndim, (standard_ndim, extra_ndim)) + + dims = () + if is_epoched: + dims += ("epoch",) + if events is not None and data.shape[0] != events.shape[0]: + raise ValueError( + f"The first dimension of `data` ({data.shape[0]}) must match the first " + f"dimension of `events` ({events.shape[0]})." + ) + + dims += ("channel",) want_n_chan = _pick_data_channels(info).size - want_n_freq = freqs.size - got_n_chan, got_n_freq = data.shape[-2:] + got_n_chan = data.shape[list(dims).index("channel")] if got_n_chan != want_n_chan: raise ValueError( - f"The number of channels in `data` ({got_n_chan}) must match the " - f"number of good data channels in `info` ({want_n_chan})." + f"The number of channels in `data` ({got_n_chan}) must match the number of " + f"good data channels in `info` ({want_n_chan})." ) + + if data.ndim == extra_ndim: # i.e. segments or tapers present + _check_option( + "method", method, ("multitaper", "welch"), f" when data.ndim={extra_ndim}" + ) # require method specified to differentiate segments from tapers + if method == "multitaper": + actual = None if mt_weights is None else mt_weights.size + expected = data.shape[-2] + if actual != expected: + raise ValueError( + f"Expected size of `mt_weights` to be {expected}, got {actual}." + ) + dims += ("taper", "freq") + else: # i.e. welch + dims += ("freq", "segment") + else: + dims += ("freq",) + want_n_freq = freqs.size + got_n_freq = data.shape[list(dims).index("freq")] if got_n_freq != want_n_freq: raise ValueError( - f"The last dimension of `data` ({got_n_freq}) must have the same " - f"number of elements as `freqs` ({want_n_freq})." + f"The number of frequencies in `data` ({got_n_freq}) must match the number " + f"of elements in `freqs` ({want_n_freq})." ) + return dims + @fill_doc class SpectrumArray(Spectrum): @@ -1203,10 +1270,17 @@ class SpectrumArray(Spectrum): Parameters ---------- - data : array, shape (n_channels, n_freqs) - The power spectral density for each channel. + data : ndarray, shape (n_channels, [n_tapers], n_freqs, [n_segments]) + The spectra for each channel. %(info_not_none)s %(freqs_tfr_array)s + method : str + The spectral estimation method. Default ``'unknown'``. Must be provided if data + contains unaggregated tapers (``method='multitaper'``) or segments + (``method='welch'``). + mt_weights : ndarray | None + The multitaper weights used for averaging across tapers. Only required if data + from :func:`~mne.time_frequency.psd_array_multitaper` with ``output='complex'``. %(verbose)s See Also @@ -1229,21 +1303,34 @@ def __init__( data, info, freqs, + method="unknown", + mt_weights=None, *, verbose=None, ): - _check_data_shape(data, freqs, info, ndim=2) + dims = _check_data_shape( + data=data, + info=info, + freqs=freqs, + is_epoched=False, + events=None, + method=method, + standard_ndim=2, + extra_ndim=3, + mt_weights=mt_weights, + ) self.__setstate__( dict( - method="unknown", + method=method, data=data, sfreq=info["sfreq"], - dims=("channel", "freq"), + dims=dims, freqs=freqs, inst_type_str="Array", - data_type="Power Spectrum", + data_type=f"{'Complex' if np.iscomplexobj(data) else 'Real'} Spectrum", info=info, + mt_weights=mt_weights, ) ) @@ -1280,7 +1367,10 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): have been computed. %(info_not_none)s method : str - The method used to compute the spectrum ('welch' or 'multitaper'). + The method used to compute the spectrum (``'welch'`` or ``'multitaper'``). + mt_weights : array | None + The weights for each taper. Only present if spectra computed with + ``method='multitaper'`` and ``output='complex'``. See Also -------- @@ -1420,6 +1510,11 @@ def average(self, method="mean"): "supported. Consider averaging the signals before computing " "the Welch spectrum estimates." ) + if "taper" in self._dims: + raise NotImplementedError( + "Averaging multitaper tapers across epochs is not supported. Consider " + "averaging the signals before computing the complex spectrum." + ) # serialize the object and update data, dims, and data type state = super().__getstate__() state["nave"] = state["data"].shape[0] @@ -1449,12 +1544,19 @@ class EpochsSpectrumArray(EpochsSpectrum): Parameters ---------- - data : array, shape (n_epochs, n_channels, n_freqs) - The power spectral density for each channel in each epoch. + data : ndarray, shape (n_epochs, n_channels, [n_tapers], n_freqs, [n_segments]) + The spectra for each channel in each epoch. %(info_not_none)s %(freqs_tfr_array)s %(events_epochs)s %(event_id)s + method : str + The spectral estimation method. Default ``'unknown'``. Must be provided if data + contains unaggregated tapers (``method='multitaper'``) or segments + (``method='welch'``). + mt_weights : ndarray | None + The multitaper weights used for averaging across tapers. Only required if data + from :func:`~mne.time_frequency.psd_array_multitaper` with ``output='complex'``. %(verbose)s See Also @@ -1478,31 +1580,39 @@ def __init__( freqs, events=None, event_id=None, + method="unknown", + mt_weights=None, *, verbose=None, ): - _check_data_shape(data, freqs, info, ndim=3) - if events is not None and data.shape[0] != events.shape[0]: - raise ValueError( - f"The first dimension of `data` ({data.shape[0]}) must match the " - f"first dimension of `events` ({events.shape[0]})." - ) + dims = _check_data_shape( + data=data, + info=info, + freqs=freqs, + is_epoched=True, + events=events, + method=method, + standard_ndim=3, + extra_ndim=4, + mt_weights=mt_weights, + ) self.__setstate__( dict( - method="unknown", + method=method, data=data, sfreq=info["sfreq"], - dims=("epoch", "channel", "freq"), + dims=dims, freqs=freqs, inst_type_str="Array", - data_type="Power Spectrum", + data_type=f"{'Complex' if np.iscomplexobj(data) else 'Real'} Spectrum", info=info, events=events, event_id=event_id, metadata=None, selection=np.arange(data.shape[0]), drop_log=tuple(tuple() for _ in range(data.shape[0])), + mt_weights=mt_weights, ) ) diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index a44c6aeaa17..00194a86a71 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -6,10 +6,12 @@ import numpy as np import pytest from matplotlib.colors import same_color -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose, assert_array_equal -from mne import Annotations +from mne import Annotations, create_info, make_fixed_length_epochs +from mne.io import RawArray from mne.time_frequency import read_spectrum +from mne.time_frequency.multitaper import _psd_from_mt from mne.time_frequency.spectrum import EpochsSpectrumArray, SpectrumArray from mne.utils import _record_warnings @@ -22,8 +24,6 @@ def test_compute_psd_errors(raw): raw.compute_psd(foo=None) with pytest.raises(TypeError, match="keyword arguments foo, bar for"): raw.compute_psd(foo=None, bar=None) - with pytest.raises(ValueError, match="Complex output is not supported in "): - raw.compute_psd(output="complex") raw.set_annotations(Annotations(onset=0.01, duration=0.01, description="bad_foo")) with pytest.raises(NotImplementedError, match='Cannot use method="multitaper"'): raw.compute_psd(method="multitaper", reject_by_annotation=True) @@ -33,7 +33,7 @@ def test_compute_psd_errors(raw): @pytest.mark.parametrize( ( "fmin, fmax, tmin, tmax, picks, proj, n_fft, n_overlap, n_per_seg, " - "average, window, bandwidth, adaptive, low_bias, normalization" + "average, window, bandwidth, adaptive, low_bias, normalization, output" ), [ [ @@ -52,6 +52,7 @@ def test_compute_psd_errors(raw): False, True, "length", + "power", ], # defaults [ 5, @@ -69,7 +70,26 @@ def test_compute_psd_errors(raw): True, False, "full", + "power", # XXX: technically a default ], # non-defaults + [ + 0, + np.inf, + None, + None, + None, + False, + 256, + 0, + None, + "mean", + "hamming", + None, + False, + True, + "length", + "complex", + ], # complex XXX: need to also test with non-defaults? ], ) def test_spectrum_params( @@ -89,6 +109,7 @@ def test_spectrum_params( adaptive, low_bias, normalization, + output, raw, ): """Test valid parameter combinations in the .compute_psd() method.""" @@ -100,6 +121,7 @@ def test_spectrum_params( tmax=tmax, picks=picks, proj=proj, + output=output, ) if method == "welch": kwargs.update( @@ -260,6 +282,69 @@ def test_spectrum_to_data_frame(inst, request, evoked): assert_frame_equal(_pick_first, _pick_last) +def _complex_helper(df, weights, group_cols): + """Convert complex spectrum to power after conversion to DataFrame.""" + from pandas import Series + + unagged_columns = df[group_cols].iloc[0].values.tolist() + x = df.drop(columns=group_cols).values[np.newaxis].T + if weights is None: + psd = np.mean((x * x.conj()).real * 2, axis=1) + else: + psd = _psd_from_mt(x, weights) + psd = np.atleast_1d(np.squeeze(psd)).tolist() + _df = dict(zip(df.columns, unagged_columns + psd)) + return Series(_df) + + +@pytest.mark.parametrize("long_format", (False, True)) +@pytest.mark.parametrize( + "method, output", + [("welch", "complex"), ("welch", "power"), ("multitaper", "complex")], +) +def test_unaggregated_spectrum_to_data_frame(raw, long_format, method, output): + """Test converting unaggregated spectra (multiple segments/tapers) to data frame.""" + pytest.importorskip("pandas") + from pandas.testing import assert_frame_equal + + from mne.utils.dataframe import _inplace + + # aggregated spectrum → dataframe + orig_df = raw.compute_psd(method=method).to_data_frame(long_format=long_format) + # unaggregated welch or complex multitaper → + # aggregate w/ pandas (to make sure we did reshaping right) + kwargs = dict() + if method == "welch": + kwargs.update(average=False) + spectrum = raw.compute_psd(method=method, output=output, **kwargs) + df = spectrum.to_data_frame(long_format=long_format) + grouping_cols = ["freq"] + drop_cols = ["segment"] if method == "welch" else ["taper"] + if long_format: + grouping_cols.append("channel") + drop_cols.append("ch_type") + orig_df.drop(columns="ch_type", inplace=True) + # only do a couple freq bins, otherwise test takes forever for multitaper + subset = partial(np.isin, test_elements=spectrum.freqs[:2]) + df = df.loc[subset(df["freq"])] + orig_df = orig_df.loc[subset(orig_df["freq"])] + # sort orig_df, because at present we can't actually prevent pandas from + # sorting at the agg step *sigh* + _inplace(orig_df, "sort_values", by=grouping_cols, ignore_index=True) + # aggregate + df = df.drop(columns=drop_cols) + gb = df.groupby(grouping_cols, as_index=False, observed=False) + if output == "complex": + gb = gb[df.columns] # https://github.com/pandas-dev/pandas/pull/52477 + agg_df = gb.apply(_complex_helper, spectrum.mt_weights, grouping_cols) + else: + agg_df = gb.mean() # excludes missing values itself + # even with check_categorical=False, we know that the *data* matches; + # what may differ is the order of the "levels" in the *metadata* for the + # channel name column + assert_frame_equal(agg_df, orig_df, check_categorical=False) + + # not testing with Evoked because it already has projs applied @pytest.mark.parametrize("inst", ("raw", "epochs")) def test_spectrum_proj(inst, request): @@ -275,6 +360,58 @@ def test_spectrum_proj(inst, request): assert has_proj == no_proj +@pytest.mark.parametrize( + "method, average", [("welch", False), ("welch", "mean"), ("multitaper", None)] +) +def test_spectrum_complex(method, average): + """Test output='complex' support.""" + sfreq = 100 + n = 10 * sfreq + freq = 3.0 + phase = np.pi / 4 # should be recoverable + data = np.cos(2 * np.pi * freq * np.arange(n) / sfreq + phase)[np.newaxis] + raw = RawArray(data, create_info(1, sfreq, "eeg")) + epochs = make_fixed_length_epochs(raw, duration=2.0, preload=True) + assert len(epochs) == 5 + assert len(epochs.times) == 2 * sfreq + kwargs = dict(output="complex", method=method) + if method == "welch": + kwargs["n_fft"] = sfreq + want_dims = ("epoch", "channel", "freq") + want_shape = (5, 1, sfreq // 2 + 1) + if not average: + want_dims = want_dims + ("segment",) + want_shape = want_shape + (2,) + kwargs["average"] = average + else: + assert method == "multitaper" + assert not average + want_dims = ("epoch", "channel", "taper", "freq") + want_shape = (5, 1, 7, sfreq + 1) + spectrum = epochs.compute_psd(**kwargs) + idx = np.argmin(np.abs(spectrum.freqs - freq)) + assert spectrum.freqs[idx] == freq + assert spectrum._dims == want_dims + assert spectrum.shape == want_shape + data = spectrum.get_data() + assert data.dtype == np.complex128 + coef = spectrum.get_data(fmin=freq, fmax=freq).mean(0) + if method == "multitaper": + coef = coef[..., 0, :] # first taper + elif not average: + coef = coef.mean(-1) # over segments + coef = coef.item() + # Test phase matches what was simulated + assert_allclose(np.angle(coef), phase, rtol=1e-4) + # Now test that it warns appropriately + epochs._data[0, 0, :] = 0 # actually zero for one epoch and ch + with pytest.warns(UserWarning, match="Zero value.*channel 0"): + epochs.compute_psd(**kwargs) + # But not if we mark that channel as bad + epochs.info["bads"] = epochs.ch_names[:1] + epochs.compute_psd(**kwargs) + + def test_spectrum_kwarg_triaging(raw): """Test kwarg triaging in legacy plot_psd() method.""" import matplotlib.pyplot as plt @@ -295,44 +432,135 @@ def _check_spectrum_equivalent(spect1, spect2, tmp_path): assert_array_equal(spect1.freqs, spect2.freqs) -def test_spectrum_array_errors(epochs_spectrum): - """Test EpochsSpectrumArray constructor errors.""" - data, freqs = epochs_spectrum.get_data(return_freqs=True) - info = epochs_spectrum.info - with pytest.raises(ValueError, match="Data must be a 3D array"): - EpochsSpectrumArray(np.empty((2, 3, 4, 5)), info, freqs) +@pytest.mark.parametrize("kind", ("raw", "epochs")) +@pytest.mark.parametrize( + "method, output, average", + [ + ("welch", "power", "mean"), # test with precomputed spectrum + ("welch", "power", False), # unaggregated segments + ("multitaper", "complex", None), # unaggregated tapers + ], +) +def test_spectrum_array_errors(kind, method, output, average, request): + """Test (Epochs)SpectrumArray constructor errors.""" + if method == "welch" and output == "power" and average: + spectrum = request.getfixturevalue(f"{kind}_spectrum") + else: + data = request.getfixturevalue(kind) + kwargs = dict() + if method == "welch": + kwargs.update(average=average) + spectrum = data.compute_psd(method=method, output=output, **kwargs) + data, freqs = spectrum.get_data(return_freqs=True) + info = spectrum.info + mt_weights = spectrum.mt_weights + Klass = SpectrumArray if kind == "raw" else EpochsSpectrumArray + # test mismatching number of channels + bad_n_chans = data[:-1] if kind == "raw" else data[:, :-1] with pytest.raises(ValueError, match=r"number of channels.*good data channels"): - EpochsSpectrumArray(data[:, :-1], info, freqs) - with pytest.raises(ValueError, match=r"last dimension.*same number of elements"): - EpochsSpectrumArray(data[..., :-1], info, freqs) + Klass(bad_n_chans, info, freqs) + # test mismatching number of frequencies + bad_n_freqs = ( + data[..., :-1, :] if method == "welch" and not average else data[..., :-1] + ) + with pytest.raises(ValueError, match=r"number of frequencies.*number of elements"): + Klass(bad_n_freqs, info, freqs, method=method, mt_weights=mt_weights) # test mismatching events shape - n_epo = data.shape[0] + 1 # +1 so they purposely don't match - events = np.vstack( - (np.arange(n_epo), np.zeros(n_epo, dtype=int), np.ones(n_epo, dtype=int)) - ).T - with pytest.raises(ValueError, match=r"first dimension.*dimension of `events`"): - EpochsSpectrumArray(data, info, freqs, events) + if kind == "epochs": + n_epo = data.shape[0] + 1 # +1 so they purposely don't match + events = np.vstack( + (np.arange(n_epo), np.zeros(n_epo, dtype=int), np.ones(n_epo, dtype=int)) + ).T + with pytest.raises(ValueError, match=r"first dimension.*dimension of `events`"): + Klass(data, info, freqs, events) + # test unspecified method for unaggregated spectra (i.e. with segments or tapers) + if ( + method == "welch" + and not average + or method == "multitaper" + and output == "complex" + ): + with pytest.raises( + ValueError, match="Invalid value for the 'method' parameter" + ): + Klass(data, info, freqs, method="unknown", mt_weights=mt_weights) + # test unspecified/mismatched multitaper weights + if method == "multitaper" and output == "complex": + with pytest.raises( + ValueError, match=r"Expected size of `mt_weights` to be.*, got" + ): + Klass(data, info, freqs, method=method, mt_weights=None) + with pytest.raises( + ValueError, match=r"Expected size of `mt_weights` to be.*, got" + ): + Klass(data, info, freqs, method=method, mt_weights=mt_weights[:, :-1]) @pytest.mark.parametrize("kind", ("raw", "epochs")) -def test_spectrum_array(kind, tmp_path, request): +@pytest.mark.parametrize( + "method, output, average", + [ + ("welch", "power", "mean"), # test with precomputed spectrum + ("welch", "power", False), + ("welch", "complex", False), + ("welch", "complex", "mean"), + ("multitaper", "complex", None), + ], +) +def test_spectrum_array(kind, method, output, average, tmp_path, request): """Test EpochsSpectrumArray and SpectrumArray constructors.""" - spectrum = request.getfixturevalue(f"{kind}_spectrum") + if method == "welch" and output == "power" and average: + spectrum = request.getfixturevalue(f"{kind}_spectrum") + else: + data = request.getfixturevalue(kind) + kwargs = dict() + if method == "welch": + kwargs.update(average=average) + spectrum = data.compute_psd(method=method, output=output, **kwargs) data, freqs = spectrum.get_data(return_freqs=True) Klass = SpectrumArray if kind == "raw" else EpochsSpectrumArray - spect_arr = Klass(data=data, info=spectrum.info, freqs=freqs) + spect_arr = Klass( + data=data, + info=spectrum.info, + freqs=freqs, + method=method, + mt_weights=spectrum.mt_weights, + ) _check_spectrum_equivalent(spectrum, spect_arr, tmp_path) @pytest.mark.parametrize("kind", ("raw", "epochs")) @pytest.mark.parametrize("array", (False, True)) -def test_plot_spectrum(kind, array, request): +@pytest.mark.parametrize( + "method, output, average", + [ + ("welch", "power", "mean"), # test with precomputed spectrum + ("welch", "power", False), + ("welch", "complex", False), + ("welch", "complex", "mean"), + ("multitaper", "complex", None), + ], +) +def test_plot_spectrum(kind, array, method, output, average, request): """Test plotting (Epochs)Spectrum(Array).""" - spectrum = request.getfixturevalue(f"{kind}_spectrum") + if method == "welch" and output == "power" and average: + spectrum = request.getfixturevalue(f"{kind}_spectrum") + else: + data = request.getfixturevalue(kind) + kwargs = dict() + if method == "welch": + kwargs.update(average=average) + spectrum = data.compute_psd(method=method, output=output, **kwargs) if array: data, freqs = spectrum.get_data(return_freqs=True) Klass = SpectrumArray if kind == "raw" else EpochsSpectrumArray - spectrum = Klass(data=data, info=spectrum.info, freqs=freqs) + spectrum = Klass( + data=data, + info=spectrum.info, + freqs=freqs, + method=spectrum.method, + mt_weights=spectrum.mt_weights, + ) spectrum.info["bads"] = spectrum.ch_names[:1] # one grad channel spectrum.plot(average=True, amplitude=True, spatial_colors=True) spectrum.plot(average=True, amplitude=False, spatial_colors=False) From 4fa8466ce6d7b5dddfaa639915dde9e1c4bac8df Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Jul 2024 16:31:30 +0000 Subject: [PATCH 02/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/time_frequency/spectrum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 7f4c2979710..b6a233b1679 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -43,7 +43,7 @@ _is_numeric, check_fname, ) -from ..utils.misc import _identity_function, _pl +from ..utils.misc import _pl from ..utils.spectrum import _get_instance_type_string, _split_psd_kwargs from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap From 891902ce23cb289a063c200a3bffb02625c3a1e6 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 25 Jul 2024 12:36:46 +0200 Subject: [PATCH 03/16] Try fix failing unaggr dataframe test --- mne/time_frequency/tests/test_spectrum.py | 26 +++++++++++++---------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 00194a86a71..37ed2607f41 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -282,16 +282,13 @@ def test_spectrum_to_data_frame(inst, request, evoked): assert_frame_equal(_pick_first, _pick_last) -def _complex_helper(df, weights, group_cols): - """Convert complex spectrum to power after conversion to DataFrame.""" +def _agg_helper(df, weights, group_cols): + """Aggregate complex multitaper spectrum after conversion to DataFrame.""" from pandas import Series unagged_columns = df[group_cols].iloc[0].values.tolist() - x = df.drop(columns=group_cols).values[np.newaxis].T - if weights is None: - psd = np.mean((x * x.conj()).real * 2, axis=1) - else: - psd = _psd_from_mt(x, weights) + x_mt = df.drop(columns=group_cols).values[np.newaxis].T + psd = _psd_from_mt(x_mt, weights) psd = np.atleast_1d(np.squeeze(psd)).tolist() _df = dict(zip(df.columns, unagged_columns + psd)) return Series(_df) @@ -334,11 +331,18 @@ def test_unaggregated_spectrum_to_data_frame(raw, long_format, method, output): # aggregate df = df.drop(columns=drop_cols) gb = df.groupby(grouping_cols, as_index=False, observed=False) - if output == "complex": - gb = gb[df.columns] # https://github.com/pandas-dev/pandas/pull/52477 - agg_df = gb.apply(_complex_helper, spectrum.mt_weights, grouping_cols) + if method == "welch": + if output == "complex": + + def _fun(x): + return np.nanmean(np.real(x * np.conj(x))) + + agg_df = gb.agg(_fun) + else: + agg_df = gb.mean() # excludes missing values itself else: - agg_df = gb.mean() # excludes missing values itself + gb = gb[df.columns] # https://github.com/pandas-dev/pandas/pull/52477 + agg_df = gb.apply(_agg_helper, spectrum.mt_weights, grouping_cols) # even with check_categorical=False, we know that the *data* matches; # what may differ is the order of the "levels" in the *metadata* for the # channel name column From 52a25c0a1415d32fba113e349f346dc36fb7b38e Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 25 Jul 2024 12:50:27 -0400 Subject: [PATCH 04/16] FIX: Test --- mne/time_frequency/tests/test_spectrum.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 37ed2607f41..2ed9595954e 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -346,7 +346,10 @@ def _fun(x): # even with check_categorical=False, we know that the *data* matches; # what may differ is the order of the "levels" in the *metadata* for the # channel name column - assert_frame_equal(agg_df, orig_df, check_categorical=False) + agg_df.sort_values(by=grouping_cols, ignore_index=True, inplace=True) + orig_df.sort_values(by=grouping_cols, ignore_index=True, inplace=True) + # One can have categorical dtype and the other plain object, so don't check that + assert_frame_equal(agg_df, orig_df, check_categorical=False, check_dtype=False) # not testing with Evoked because it already has projs applied From 62fbcb7c9caf801dc540e4755c1982223755e1f0 Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Fri, 26 Jul 2024 22:07:32 +0200 Subject: [PATCH 05/16] Apply suggestions from code review Co-authored-by: Daniel McCloy --- mne/time_frequency/spectrum.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index b6a233b1679..1e7fdbfaf10 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -391,7 +391,7 @@ def __setstate__(self, state): self.info = Info(**state["info"]) self._data_type = state["data_type"] self._nave = state.get("nave") # objs saved before #11282 won't have `nave` - self._mt_weights = state.get("mt_weights") # objs saved before #XXX won't have + self._mt_weights = state.get("mt_weights") # objs saved before #12747 won't have self.preload = True # instance type inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) @@ -440,7 +440,7 @@ def _check_values(self): warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}', UserWarning) def _returns_complex_tapers(self, **method_kw): - return self.method == "multitaper" and method_kw.get("output", "") == "complex" + return self.method == "multitaper" and method_kw.get("output") == "complex" def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): # make the spectra @@ -462,7 +462,7 @@ def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): # (and then deleted afterwards) self._shape = (len(self.ch_names), len(self.freqs)) # append n_welch_segments - if method_kw.get("average", "") in (None, False): + if method_kw.get("average") in (None, False): n_welch_segments = _compute_n_welch_segments(data.shape[-1], method_kw) self._shape += (n_welch_segments,) # insert n_tapers From 3b9d800b43e3a47cb0851bc0b23b4a24440fcfe3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Jul 2024 20:07:52 +0000 Subject: [PATCH 06/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/time_frequency/spectrum.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 1e7fdbfaf10..c3a8bcbcf83 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -391,7 +391,9 @@ def __setstate__(self, state): self.info = Info(**state["info"]) self._data_type = state["data_type"] self._nave = state.get("nave") # objs saved before #11282 won't have `nave` - self._mt_weights = state.get("mt_weights") # objs saved before #12747 won't have + self._mt_weights = state.get( + "mt_weights" + ) # objs saved before #12747 won't have self.preload = True # instance type inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) From 361db716396adfbe4ea0a2f9b837bc7ccd132dd4 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 26 Jul 2024 23:06:31 +0200 Subject: [PATCH 07/16] Update (partly) from review --- mne/time_frequency/spectrum.py | 13 +++--- mne/time_frequency/tests/test_spectrum.py | 50 ++++++++++++----------- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index c3a8bcbcf83..255ca9be312 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -350,8 +350,9 @@ def __init__( self._dims = self._dims[:-1] + ("taper",) + self._dims[-1:] # record data type (for repr and html_repr) self._data_type = ( - f"{'Complex' if method_kw.get('output', '') == 'complex' else 'Real'} " - "Spectrum" + "Fourier Coefficients" + if method_kw.get("output") == "complex" + else "Power Spectrum" ) # set nave (child constructor overrides this for Evoked input) self._nave = None @@ -391,9 +392,7 @@ def __setstate__(self, state): self.info = Info(**state["info"]) self._data_type = state["data_type"] self._nave = state.get("nave") # objs saved before #11282 won't have `nave` - self._mt_weights = state.get( - "mt_weights" - ) # objs saved before #12747 won't have + self._mt_weights = state.get("mt_weights") # objs before #12747 won't have self.preload = True # instance type inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) @@ -463,8 +462,8 @@ def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): # this is *expected* shape, it gets asserted later in _check_values() # (and then deleted afterwards) self._shape = (len(self.ch_names), len(self.freqs)) - # append n_welch_segments - if method_kw.get("average") in (None, False): + # append n_welch_segments (use "" as .get() default since None considered valid) + if method_kw.get("average", "") in (None, False): n_welch_segments = _compute_n_welch_segments(data.shape[-1], method_kw) self._shape += (n_welch_segments,) # insert n_tapers diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 2ed9595954e..f63c1e74ba9 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -70,8 +70,8 @@ def test_compute_psd_errors(raw): True, False, "full", - "power", # XXX: technically a default - ], # non-defaults + "power", + ], # non-defaults (excluding output) [ 0, np.inf, @@ -89,7 +89,7 @@ def test_compute_psd_errors(raw): True, "length", "complex", - ], # complex XXX: need to also test with non-defaults? + ], # complex (testing with non-defaults doesn't increase coverage) ], ) def test_spectrum_params( @@ -335,13 +335,13 @@ def test_unaggregated_spectrum_to_data_frame(raw, long_format, method, output): if output == "complex": def _fun(x): - return np.nanmean(np.real(x * np.conj(x))) + return np.mean(np.real(x * np.conj(x))) # use mean to aggregate agg_df = gb.agg(_fun) else: agg_df = gb.mean() # excludes missing values itself else: - gb = gb[df.columns] # https://github.com/pandas-dev/pandas/pull/52477 + gb = gb[df.columns] # XXX: try removing when minimum pandas >= 2.1 is required agg_df = gb.apply(_agg_helper, spectrum.mt_weights, grouping_cols) # even with check_categorical=False, we know that the *data* matches; # what may differ is the order of the "levels" in the *metadata* for the @@ -439,7 +439,6 @@ def _check_spectrum_equivalent(spect1, spect2, tmp_path): assert_array_equal(spect1.freqs, spect2.freqs) -@pytest.mark.parametrize("kind", ("raw", "epochs")) @pytest.mark.parametrize( "method, output, average", [ @@ -448,12 +447,12 @@ def _check_spectrum_equivalent(spect1, spect2, tmp_path): ("multitaper", "complex", None), # unaggregated tapers ], ) -def test_spectrum_array_errors(kind, method, output, average, request): - """Test (Epochs)SpectrumArray constructor errors.""" +def test_spectrum_array_errors(method, output, average, request): + """Test EpochsSpectrumArray constructor errors.""" if method == "welch" and output == "power" and average: - spectrum = request.getfixturevalue(f"{kind}_spectrum") + spectrum = request.getfixturevalue("epochs_spectrum") else: - data = request.getfixturevalue(kind) + data = request.getfixturevalue("epochs") kwargs = dict() if method == "welch": kwargs.update(average=average) @@ -461,25 +460,24 @@ def test_spectrum_array_errors(kind, method, output, average, request): data, freqs = spectrum.get_data(return_freqs=True) info = spectrum.info mt_weights = spectrum.mt_weights - Klass = SpectrumArray if kind == "raw" else EpochsSpectrumArray # test mismatching number of channels - bad_n_chans = data[:-1] if kind == "raw" else data[:, :-1] with pytest.raises(ValueError, match=r"number of channels.*good data channels"): - Klass(bad_n_chans, info, freqs) + EpochsSpectrumArray(data[:, :-1], info, freqs) # test mismatching number of frequencies bad_n_freqs = ( data[..., :-1, :] if method == "welch" and not average else data[..., :-1] ) with pytest.raises(ValueError, match=r"number of frequencies.*number of elements"): - Klass(bad_n_freqs, info, freqs, method=method, mt_weights=mt_weights) + EpochsSpectrumArray( + bad_n_freqs, info, freqs, method=method, mt_weights=mt_weights + ) # test mismatching events shape - if kind == "epochs": - n_epo = data.shape[0] + 1 # +1 so they purposely don't match - events = np.vstack( - (np.arange(n_epo), np.zeros(n_epo, dtype=int), np.ones(n_epo, dtype=int)) - ).T - with pytest.raises(ValueError, match=r"first dimension.*dimension of `events`"): - Klass(data, info, freqs, events) + n_epo = data.shape[0] + 1 # +1 so they purposely don't match + events = np.vstack( + (np.arange(n_epo), np.zeros(n_epo, dtype=int), np.ones(n_epo, dtype=int)) + ).T + with pytest.raises(ValueError, match=r"first dimension.*dimension of `events`"): + EpochsSpectrumArray(data, info, freqs, events) # test unspecified method for unaggregated spectra (i.e. with segments or tapers) if ( method == "welch" @@ -490,17 +488,21 @@ def test_spectrum_array_errors(kind, method, output, average, request): with pytest.raises( ValueError, match="Invalid value for the 'method' parameter" ): - Klass(data, info, freqs, method="unknown", mt_weights=mt_weights) + EpochsSpectrumArray( + data, info, freqs, method="unknown", mt_weights=mt_weights + ) # test unspecified/mismatched multitaper weights if method == "multitaper" and output == "complex": with pytest.raises( ValueError, match=r"Expected size of `mt_weights` to be.*, got" ): - Klass(data, info, freqs, method=method, mt_weights=None) + EpochsSpectrumArray(data, info, freqs, method=method, mt_weights=None) with pytest.raises( ValueError, match=r"Expected size of `mt_weights` to be.*, got" ): - Klass(data, info, freqs, method=method, mt_weights=mt_weights[:, :-1]) + EpochsSpectrumArray( + data, info, freqs, method=method, mt_weights=mt_weights[:, :-1] + ) @pytest.mark.parametrize("kind", ("raw", "epochs")) From 322ae92f4879d24a31d9ddbd1c504c594dbd1966 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Sat, 27 Jul 2024 17:55:05 +0200 Subject: [PATCH 08/16] Update (remaining) from review --- mne/time_frequency/spectrum.py | 198 ++++++++++---------- mne/time_frequency/tests/test_spectrum.py | 210 ++++++++++++---------- 2 files changed, 223 insertions(+), 185 deletions(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 255ca9be312..ccf30bd3186 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -374,7 +374,7 @@ def __getstate__(self): data_type=self._data_type, info=self.info, nave=self.nave, - mt_weights=self.mt_weights, + weights=self.weights, ) return out @@ -392,7 +392,7 @@ def __setstate__(self, state): self.info = Info(**state["info"]) self._data_type = state["data_type"] self._nave = state.get("nave") # objs saved before #11282 won't have `nave` - self._mt_weights = state.get("mt_weights") # objs before #12747 won't have + self._weights = state.get("weights") # objs saved before #12747 won't have self.preload = True # instance type inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) @@ -452,11 +452,11 @@ def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): if self._returns_complex_tapers(**method_kw): fourier_coefs, freqs, weights = result self._data = fourier_coefs - self._mt_weights = weights + self._weights = weights else: psds, freqs = result self._data = psds - self._mt_weights = None + self._weights = None # assign properties (._data already assigned above) self._freqs = freqs # this is *expected* shape, it gets asserted later in _check_values() @@ -468,7 +468,7 @@ def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): self._shape += (n_welch_segments,) # insert n_tapers if self._returns_complex_tapers(**method_kw): - self._shape = self._shape[:-1] + (self._mt_weights.size,) + self._shape[-1:] + self._shape = self._shape[:-1] + (self._weights.size,) + self._shape[-1:] # we don't need these anymore, and they make save/load harder del self._picks del self._psd_func @@ -500,8 +500,8 @@ def nave(self): return self._nave @property - def mt_weights(self): - return self._mt_weights + def weights(self): + return self._weights @property def sfreq(self): @@ -890,13 +890,12 @@ def plot_topomap( def _prepare_data_for_plot(self, data): # handle unaggregated Welch if "segment" in self._dims: - seg_axis = self._dims.index("segment") logger.info("Aggregating Welch estimates (median) before plotting...") - data = np.nanmedian(data, axis=seg_axis) + data = np.nanmedian(data, axis=self._dims.index("segment")) # handle unaggregated multitaper (also handles complex -> power) elif "taper" in self._dims: logger.info("Aggregating multitaper estimates before plotting...") - data = _psd_from_mt(data, self.mt_weights) + data = _psd_from_mt(data, self.weights) # handle complex data (should only be Welch remaining) if np.iscomplexobj(data): @@ -1090,12 +1089,12 @@ class Spectrum(BaseSpectrum): Frequencies at which the amplitude, power, or fourier coefficients have been computed. %(info_not_none)s - method : str - The method used to compute the spectrum (``'welch'`` or ``'multitaper'``). + method : "welch" | "multitaper" + The method used to compute the spectrum. nave : int | None The number of trials averaged together when generating the spectrum. ``None`` indicates no averaging is known to have occurred. - mt_weights : array | None + weights : array | None The weights for each taper. Only present if spectra computed with ``method='multitaper'`` and ``output='complex'``. @@ -1215,55 +1214,62 @@ def __getitem__(self, item): return BaseRaw._getitem(self, item, return_times=False) -def _check_data_shape( - data, info, freqs, is_epoched, events, method, standard_ndim, extra_ndim, mt_weights -): - _check_option("data.ndim", data.ndim, (standard_ndim, extra_ndim)) +def _check_data_shape(data, info, freqs, dimnames, weights): + if data.ndim != len(dimnames): + raise ValueError( + f"Expected data to have {len(dimnames)} dimensions, got {data.ndim}." + ) - dims = () - if is_epoched: - dims += ("epoch",) - if events is not None and data.shape[0] != events.shape[0]: - raise ValueError( - f"The first dimension of `data` ({data.shape[0]}) must match the first " - f"dimension of `events` ({events.shape[0]})." - ) + is_epoched = 1 if "epoch" in dimnames else 0 + allowed_dims = ["epoch", "channel", "freq", "segment", "taper"] + allowed_dims = allowed_dims[0 if is_epoched else 1 :] + if set(allowed_dims).intersection(dimnames) != set(dimnames): + raise ValueError( + f"All entries of `dimnames` must be in {allowed_dims}, got {dimnames}." + ) + if "channel" not in dimnames or "freq" not in dimnames: + raise ValueError("Both 'channel' and 'freq' must be present in `dimnames`.") - dims += ("channel",) + if list(dimnames).index("channel") != is_epoched: + raise ValueError( + f"'channel' must be the {'second' if is_epoched else 'first'} dimension of " + "the data." + ) want_n_chan = _pick_data_channels(info).size - got_n_chan = data.shape[list(dims).index("channel")] + got_n_chan = data.shape[list(dimnames).index("channel")] if got_n_chan != want_n_chan: raise ValueError( f"The number of channels in `data` ({got_n_chan}) must match the number of " f"good data channels in `info` ({want_n_chan})." ) - if data.ndim == extra_ndim: # i.e. segments or tapers present - _check_option( - "method", method, ("multitaper", "welch"), f" when data.ndim={extra_ndim}" - ) # require method specified to differentiate segments from tapers - if method == "multitaper": - actual = None if mt_weights is None else mt_weights.size - expected = data.shape[-2] - if actual != expected: - raise ValueError( - f"Expected size of `mt_weights` to be {expected}, got {actual}." - ) - dims += ("taper", "freq") - else: # i.e. welch - dims += ("freq", "segment") - else: - dims += ("freq",) + # given we limit max array size and ensure channel & freq dims present, only one of + # taper or segment can be present + if "taper" in dimnames: + if dimnames[-2] != "taper": # _psd_from_mt assumes this (called when plotting) + raise ValueError( + "'taper' must be the second to last dimension of the data." + ) + # expect weights for each taper + actual = None if weights is None else weights.size + expected = data.shape[list(dimnames).index("taper")] + if actual != expected: + raise ValueError( + f"Expected size of `weights` to be {expected} to match `data`, got " + f"{actual}." + ) + elif "segment" in dimnames and dimnames[-1] != "segment": + raise ValueError("'segment' must be the last dimension of the data.") + + # freq being in wrong position ruled out by above checks want_n_freq = freqs.size - got_n_freq = data.shape[list(dims).index("freq")] + got_n_freq = data.shape[list(dimnames).index("freq")] if got_n_freq != want_n_freq: raise ValueError( f"The number of frequencies in `data` ({got_n_freq}) must match the number " f"of elements in `freqs` ({want_n_freq})." ) - return dims - @fill_doc class SpectrumArray(Spectrum): @@ -1275,11 +1281,10 @@ class SpectrumArray(Spectrum): The spectra for each channel. %(info_not_none)s %(freqs_tfr_array)s - method : str - The spectral estimation method. Default ``'unknown'``. Must be provided if data - contains unaggregated tapers (``method='multitaper'``) or segments - (``method='welch'``). - mt_weights : ndarray | None + dimnames : tuple of str + The name of the dimensions in the data. Must contain ``'channel'`` and + ``'freq'``. Can also contain one of ``'taper'`` or ``'segment'``. + weights : ndarray | None The multitaper weights used for averaging across tapers. Only required if data from :func:`~mne.time_frequency.psd_array_multitaper` with ``output='complex'``. %(verbose)s @@ -1304,34 +1309,37 @@ def __init__( data, info, freqs, - method="unknown", - mt_weights=None, + dimnames=("channel", "freq"), + weights=None, *, verbose=None, ): - dims = _check_data_shape( - data=data, - info=info, - freqs=freqs, - is_epoched=False, - events=None, - method=method, - standard_ndim=2, - extra_ndim=3, - mt_weights=mt_weights, - ) + # (channel, [taper], freq, [segment]) + _check_option("data.ndim", data.ndim, (2, 3)) # only allow one extra dimension + + if "epoch" in dimnames: + raise ValueError( + "'`data` must not be epoched. Use mne.time_frequency." + "EpochsSpectrumArray for storing epoched spectral data." + ) + + _check_data_shape(data, info, freqs, dimnames, weights) self.__setstate__( dict( - method=method, + method="unknown", data=data, sfreq=info["sfreq"], - dims=dims, + dims=dimnames, freqs=freqs, inst_type_str="Array", - data_type=f"{'Complex' if np.iscomplexobj(data) else 'Real'} Spectrum", + data_type=( + "Fourier Coefficients" + if np.iscomplexobj(data) + else "Power Spectrum" + ), info=info, - mt_weights=mt_weights, + weights=weights, ) ) @@ -1367,9 +1375,9 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): Frequencies at which the amplitude, power, or fourier coefficients have been computed. %(info_not_none)s - method : str - The method used to compute the spectrum (``'welch'`` or ``'multitaper'``). - mt_weights : array | None + method : "welch" | "multitaper" + The method used to compute the spectrum. + weights : array | None The weights for each taper. Only present if spectra computed with ``method='multitaper'`` and ``output='complex'``. @@ -1551,11 +1559,10 @@ class EpochsSpectrumArray(EpochsSpectrum): %(freqs_tfr_array)s %(events_epochs)s %(event_id)s - method : str - The spectral estimation method. Default ``'unknown'``. Must be provided if data - contains unaggregated tapers (``method='multitaper'``) or segments - (``method='welch'``). - mt_weights : ndarray | None + dimnames : tuple of str + The name of the dimensions in the data. Must contain ``'epoch'``, ``'channel'``, + and ``'freq'``. Can also contain one of ``'taper'`` or ``'segment'``. + weights : ndarray | None The multitaper weights used for averaging across tapers. Only required if data from :func:`~mne.time_frequency.psd_array_multitaper` with ``output='complex'``. %(verbose)s @@ -1581,39 +1588,44 @@ def __init__( freqs, events=None, event_id=None, - method="unknown", - mt_weights=None, + dimnames=("epoch", "channel", "freq"), + weights=None, *, verbose=None, ): - dims = _check_data_shape( - data=data, - info=info, - freqs=freqs, - is_epoched=True, - events=events, - method=method, - standard_ndim=3, - extra_ndim=4, - mt_weights=mt_weights, - ) + # (epoch, channel, [taper], freq, [segment]) + _check_option("data.ndim", data.ndim, (3, 4)) # only allow one extra dimension + + if "epoch" not in dimnames or list(dimnames).index("epoch") != 0: + raise ValueError("'epoch' must be the first dimension of `data`.") + if events is not None and data.shape[0] != events.shape[0]: + raise ValueError( + f"The first dimension of `data` ({data.shape[0]}) must match the first " + f"dimension of `events` ({events.shape[0]})." + ) + + _check_data_shape(data, info, freqs, dimnames, weights) self.__setstate__( dict( - method=method, + method="unknown", data=data, sfreq=info["sfreq"], - dims=dims, + dims=dimnames, freqs=freqs, inst_type_str="Array", - data_type=f"{'Complex' if np.iscomplexobj(data) else 'Real'} Spectrum", + data_type=( + "Fourier Coefficients" + if np.iscomplexobj(data) + else "Power Spectrum" + ), info=info, events=events, event_id=event_id, metadata=None, selection=np.arange(data.shape[0]), drop_log=tuple(tuple() for _ in range(data.shape[0])), - mt_weights=mt_weights, + weights=weights, ) ) diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index f63c1e74ba9..1303c3e88c7 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -342,7 +342,7 @@ def _fun(x): agg_df = gb.mean() # excludes missing values itself else: gb = gb[df.columns] # XXX: try removing when minimum pandas >= 2.1 is required - agg_df = gb.apply(_agg_helper, spectrum.mt_weights, grouping_cols) + agg_df = gb.apply(_agg_helper, spectrum.weights, grouping_cols) # even with check_categorical=False, we know that the *data* matches; # what may differ is the order of the "levels" in the *metadata* for the # channel name column @@ -439,85 +439,122 @@ def _check_spectrum_equivalent(spect1, spect2, tmp_path): assert_array_equal(spect1.freqs, spect2.freqs) -@pytest.mark.parametrize( - "method, output, average", - [ - ("welch", "power", "mean"), # test with precomputed spectrum - ("welch", "power", False), # unaggregated segments - ("multitaper", "complex", None), # unaggregated tapers - ], -) -def test_spectrum_array_errors(method, output, average, request): - """Test EpochsSpectrumArray constructor errors.""" - if method == "welch" and output == "power" and average: - spectrum = request.getfixturevalue("epochs_spectrum") - else: - data = request.getfixturevalue("epochs") - kwargs = dict() - if method == "welch": - kwargs.update(average=average) - spectrum = data.compute_psd(method=method, output=output, **kwargs) - data, freqs = spectrum.get_data(return_freqs=True) - info = spectrum.info - mt_weights = spectrum.mt_weights - # test mismatching number of channels - with pytest.raises(ValueError, match=r"number of channels.*good data channels"): - EpochsSpectrumArray(data[:, :-1], info, freqs) - # test mismatching number of frequencies - bad_n_freqs = ( - data[..., :-1, :] if method == "welch" and not average else data[..., :-1] - ) - with pytest.raises(ValueError, match=r"number of frequencies.*number of elements"): +def _get_dimnames(kind, method, output, average): + dimnames = ("epoch", "channel") if kind == "epochs" else ("channel",) + if method == "welch": + dimnames += ("freq",) if average else ("freq", "segment") + else: # i.e. multitaper + dimnames += ("freq",) if output == "power" else ("taper", "freq") + return dimnames + + +def test_spectrum_array_errors(): + """Test (Epochs)SpectrumArray constructor errors.""" + n_epochs = 10 + n_chans = 5 + n_freqs = 50 + freqs = np.arange(n_freqs) + sfreq = 100 + rng = np.random.default_rng(44) + data = rng.random((n_epochs, n_chans, n_freqs)) + dimnames = ("epoch", "channel", "freq") + info = create_info(n_chans, sfreq, "eeg") + # test incorrect ndims (for SpectrumArray; allows 2-3D data) + with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): + SpectrumArray(data[0, 0, :], info, freqs, dimnames=dimnames) + with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): + SpectrumArray(np.expand_dims(data, axis=3), info, freqs, dimnames=dimnames) + # test incorrect ndims (for EpochsSpectrumArray; allows 3-4D data) + with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): + EpochsSpectrumArray(data[0, :, :], info, freqs, dimnames=dimnames) + with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): EpochsSpectrumArray( - bad_n_freqs, info, freqs, method=method, mt_weights=mt_weights + np.expand_dims(data, axis=(3, 4)), info, freqs, dimnames=dimnames ) + # test incorrect epochs location + with pytest.raises(ValueError, match="'epoch' must be the first dimension"): + EpochsSpectrumArray(data, info, freqs, dimnames=("channel", "epoch", "freq")) # test mismatching events shape - n_epo = data.shape[0] + 1 # +1 so they purposely don't match events = np.vstack( - (np.arange(n_epo), np.zeros(n_epo, dtype=int), np.ones(n_epo, dtype=int)) + ( + np.arange(n_epochs + 1), + np.zeros(n_epochs + 1, dtype=int), + np.ones(n_epochs + 1, dtype=int), + ) ).T with pytest.raises(ValueError, match=r"first dimension.*dimension of `events`"): - EpochsSpectrumArray(data, info, freqs, events) - # test unspecified method for unaggregated spectra (i.e. with segments or tapers) - if ( - method == "welch" - and not average - or method == "multitaper" - and output == "complex" - ): - with pytest.raises( - ValueError, match="Invalid value for the 'method' parameter" - ): - EpochsSpectrumArray( - data, info, freqs, method="unknown", mt_weights=mt_weights - ) - # test unspecified/mismatched multitaper weights - if method == "multitaper" and output == "complex": - with pytest.raises( - ValueError, match=r"Expected size of `mt_weights` to be.*, got" - ): - EpochsSpectrumArray(data, info, freqs, method=method, mt_weights=None) - with pytest.raises( - ValueError, match=r"Expected size of `mt_weights` to be.*, got" - ): - EpochsSpectrumArray( - data, info, freqs, method=method, mt_weights=mt_weights[:, :-1] - ) - - -@pytest.mark.parametrize("kind", ("raw", "epochs")) + EpochsSpectrumArray(data, info, freqs, events, dimnames=dimnames) + # test data-dimname mismatch + with pytest.raises(ValueError, match=r"Expected data to have.*dimensions, got.*"): + EpochsSpectrumArray(data, info, freqs, dimnames=dimnames[:-1]) + # test unrecognised dimnames (for SpectrumArray; epoch not allowed) + with pytest.raises(ValueError, match="`data` must not be epoched"): + SpectrumArray(data[0, :, :], info, freqs, dimnames=("epoch", "channel")) + # test unrecognised dimnames (for EpochsSpectrumArray) + with pytest.raises(ValueError, match=r"entries of `dimnames` must be in.*, got,*"): + EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "channel", "notfreq")) + # test missing dimnames + with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): + EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "channel", "channel")) + with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): + EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "freq", "freq")) + with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): + EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "epoch", "epoch")) + # test incorrect channel location (for SpectrumArray; must be 1st dim) + with pytest.raises(ValueError, match="'channel' must be the first dimension"): + SpectrumArray(data[0, :, :], info, freqs, dimnames=("freq", "channel")) + # test incorrect channel location (for EpochsSpectrumArray; must be 2nd dim) + with pytest.raises(ValueError, match="'channel' must be the second dimension"): + EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "freq", "channel")) + # test mismatching number of channels + with pytest.raises(ValueError, match=r"number of channels.*good data channels"): + EpochsSpectrumArray(data[:, :-1, :], info, freqs, dimnames=dimnames) + # test incorrect taper position + with pytest.raises(ValueError, match="'taper' must be the second to last dim"): + EpochsSpectrumArray( + np.expand_dims(data, axis=3), info, freqs, dimnames=dimnames + ("taper",) + ) + # test incorrect weight size + with pytest.raises(ValueError, match=r"Expected size of `weights` to be.*, got.*"): + EpochsSpectrumArray( + np.expand_dims(data, axis=2), + info, + freqs, + dimnames=("epoch", "channel", "taper", "freq"), + weights=None, + ) + with pytest.raises(ValueError, match=r"Expected size of `weights` to be.*, got.*"): + EpochsSpectrumArray( + np.expand_dims(data, axis=2), + info, + freqs, + dimnames=("epoch", "channel", "taper", "freq"), + weights=np.ones((1, 2, 1)), + ) + # test incorrect segment position + with pytest.raises(ValueError, match="'segment' must be the last dim"): + EpochsSpectrumArray( + np.expand_dims(data, axis=2), + info, + freqs, + dimnames=("epoch", "channel", "segment", "freq"), + ) + # test mismatching number of frequencies + with pytest.raises(ValueError, match=r"number of frequencies.*number of elements"): + EpochsSpectrumArray(data[:, :, :-1], info, freqs, dimnames=dimnames) + + @pytest.mark.parametrize( - "method, output, average", + "kind, method, output, average", [ - ("welch", "power", "mean"), # test with precomputed spectrum - ("welch", "power", False), - ("welch", "complex", False), - ("welch", "complex", "mean"), - ("multitaper", "complex", None), - ], + ("raw", "welch", "power", "mean"), # test with precomputed spectrum + ("epochs", "welch", "power", False), # test with segments + ("epochs", "multitaper", "complex", None), # test with tapers + ], # additional variants don't improve coverage ) def test_spectrum_array(kind, method, output, average, tmp_path, request): """Test EpochsSpectrumArray and SpectrumArray constructors.""" + dimnames = _get_dimnames(kind, method, output, average) if method == "welch" and output == "power" and average: spectrum = request.getfixturevalue(f"{kind}_spectrum") else: @@ -532,44 +569,33 @@ def test_spectrum_array(kind, method, output, average, tmp_path, request): data=data, info=spectrum.info, freqs=freqs, - method=method, - mt_weights=spectrum.mt_weights, + dimnames=dimnames, + weights=spectrum.weights, ) _check_spectrum_equivalent(spectrum, spect_arr, tmp_path) -@pytest.mark.parametrize("kind", ("raw", "epochs")) -@pytest.mark.parametrize("array", (False, True)) @pytest.mark.parametrize( "method, output, average", [ ("welch", "power", "mean"), # test with precomputed spectrum - ("welch", "power", False), - ("welch", "complex", False), - ("welch", "complex", "mean"), - ("multitaper", "complex", None), - ], + ("welch", "complex", False), # test aggr over segments & conversion to power + ("multitaper", "complex", None), # test aggr over tapers & conversion to power + ], # additional variants don't improve coverage ) -def test_plot_spectrum(kind, array, method, output, average, request): - """Test plotting (Epochs)Spectrum(Array).""" +def test_plot_spectrum(method, output, average, request): + """Test plotting EpochsSpectrum(Array). + + Testing Spectrum(Array) with raw data doesn't improve coverage. + """ if method == "welch" and output == "power" and average: - spectrum = request.getfixturevalue(f"{kind}_spectrum") + spectrum = request.getfixturevalue("epochs_spectrum") else: - data = request.getfixturevalue(kind) + data = request.getfixturevalue("epochs") kwargs = dict() if method == "welch": kwargs.update(average=average) spectrum = data.compute_psd(method=method, output=output, **kwargs) - if array: - data, freqs = spectrum.get_data(return_freqs=True) - Klass = SpectrumArray if kind == "raw" else EpochsSpectrumArray - spectrum = Klass( - data=data, - info=spectrum.info, - freqs=freqs, - method=spectrum.method, - mt_weights=spectrum.mt_weights, - ) spectrum.info["bads"] = spectrum.ch_names[:1] # one grad channel spectrum.plot(average=True, amplitude=True, spatial_colors=True) spectrum.plot(average=True, amplitude=False, spatial_colors=False) From 43ad99277b7b58f496e161d9b3886024abe8deee Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 29 Jul 2024 12:06:21 +0200 Subject: [PATCH 09/16] Add changelog entry Co-Authored-By: Daniel McCloy <1810515+drammock@users.noreply.github.com> Co-Authored-By: Alex Rockhill Co-Authored-By: Eric Larson --- doc/changes/devel/12747.newfeature.rst | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 doc/changes/devel/12747.newfeature.rst diff --git a/doc/changes/devel/12747.newfeature.rst b/doc/changes/devel/12747.newfeature.rst new file mode 100644 index 00000000000..2957117b778 --- /dev/null +++ b/doc/changes/devel/12747.newfeature.rst @@ -0,0 +1,3 @@ +Add support for storing Fourier coefficients in :class:`mne.time_frequency.Spectrum`, +:class:`mne.time_frequency.EpochsSpectrum`, :class:`mne.time_frequency.SpectrumArray`, +and :class:`mne.time_frequency.EpochsSpectrumArray` objects, by `Thomas Binns`_. \ No newline at end of file From 85ef415352c544fff29494fdb3289795a6695866 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 29 Jul 2024 16:46:13 +0200 Subject: [PATCH 10/16] Allow __getstate__ call from SpectrumArray --- mne/utils/spectrum.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mne/utils/spectrum.py b/mne/utils/spectrum.py index 92ed4170c83..4425616f93d 100644 --- a/mne/utils/spectrum.py +++ b/mne/utils/spectrum.py @@ -9,6 +9,8 @@ def _get_instance_type_string(inst): """Get string representation of the originating instance type.""" + from numpy import ndarray + from ..epochs import BaseEpochs from ..evoked import Evoked, EvokedArray from ..io import BaseRaw @@ -20,6 +22,8 @@ def _get_instance_type_string(inst): inst_type_str = "Epochs" elif inst._inst_type in (Evoked, EvokedArray): inst_type_str = "Evoked" + elif inst._inst_type == ndarray: + inst_type_str = "Array" else: raise RuntimeError( f"Unknown instance type {inst._inst_type} in {type(inst).__name__}" From 01f26c39c4db80ffc7fccaa58d9f6515f34aec08 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 30 Jul 2024 12:38:45 +0200 Subject: [PATCH 11/16] Update from review --- mne/time_frequency/spectrum.py | 51 +++++++++++------------ mne/time_frequency/tests/test_spectrum.py | 4 +- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index ccf30bd3186..e35942b2a88 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -1214,23 +1214,22 @@ def __getitem__(self, item): return BaseRaw._getitem(self, item, return_times=False) -def _check_data_shape(data, info, freqs, dimnames, weights): +def _check_data_shape(data, info, freqs, dimnames, weights, is_epoched): if data.ndim != len(dimnames): raise ValueError( f"Expected data to have {len(dimnames)} dimensions, got {data.ndim}." ) - is_epoched = 1 if "epoch" in dimnames else 0 allowed_dims = ["epoch", "channel", "freq", "segment", "taper"] - allowed_dims = allowed_dims[0 if is_epoched else 1 :] - if set(allowed_dims).intersection(dimnames) != set(dimnames): - raise ValueError( - f"All entries of `dimnames` must be in {allowed_dims}, got {dimnames}." - ) + if not is_epoched: + allowed_dims.remove("epoch") + # TODO maybe we should be nice and allow plural versions of each dimname? + for dim in dimnames: + _check_option("dimnames", dim, allowed_dims) if "channel" not in dimnames or "freq" not in dimnames: raise ValueError("Both 'channel' and 'freq' must be present in `dimnames`.") - if list(dimnames).index("channel") != is_epoched: + if list(dimnames).index("channel") != int(is_epoched): raise ValueError( f"'channel' must be the {'second' if is_epoched else 'first'} dimension of " "the data." @@ -1255,8 +1254,8 @@ def _check_data_shape(data, info, freqs, dimnames, weights): expected = data.shape[list(dimnames).index("taper")] if actual != expected: raise ValueError( - f"Expected size of `weights` to be {expected} to match `data`, got " - f"{actual}." + f"Expected size of `weights` to be {expected} to match 'n_tapers' in " + f"`data`, got {actual}." ) elif "segment" in dimnames and dimnames[-1] != "segment": raise ValueError("'segment' must be the last dimension of the data.") @@ -1282,11 +1281,13 @@ class SpectrumArray(Spectrum): %(info_not_none)s %(freqs_tfr_array)s dimnames : tuple of str - The name of the dimensions in the data. Must contain ``'channel'`` and - ``'freq'``. Can also contain one of ``'taper'`` or ``'segment'``. + The name of the dimensions in the data, in the order they occur. Must contain + ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include + either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., + multitaper algorithms) dimension. If including ``'taper'``, you should also pass + a ``weights`` parameter. weights : ndarray | None - The multitaper weights used for averaging across tapers. Only required if data - from :func:`~mne.time_frequency.psd_array_multitaper` with ``output='complex'``. + Weights for the ``'taper'`` dimension, if present (see ``dimnames``). %(verbose)s See Also @@ -1317,13 +1318,7 @@ def __init__( # (channel, [taper], freq, [segment]) _check_option("data.ndim", data.ndim, (2, 3)) # only allow one extra dimension - if "epoch" in dimnames: - raise ValueError( - "'`data` must not be epoched. Use mne.time_frequency." - "EpochsSpectrumArray for storing epoched spectral data." - ) - - _check_data_shape(data, info, freqs, dimnames, weights) + _check_data_shape(data, info, freqs, dimnames, weights, is_epoched=False) self.__setstate__( dict( @@ -1560,11 +1555,13 @@ class EpochsSpectrumArray(EpochsSpectrum): %(events_epochs)s %(event_id)s dimnames : tuple of str - The name of the dimensions in the data. Must contain ``'epoch'``, ``'channel'``, - and ``'freq'``. Can also contain one of ``'taper'`` or ``'segment'``. + The name of the dimensions in the data, in the order they occur. Must contain + ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include + either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., + multitaper algorithms) dimension. If including ``'taper'``, you should also pass + a ``weights`` parameter. weights : ndarray | None - The multitaper weights used for averaging across tapers. Only required if data - from :func:`~mne.time_frequency.psd_array_multitaper` with ``output='complex'``. + Weights for the ``'taper'`` dimension, if present (see ``dimnames``). %(verbose)s See Also @@ -1596,7 +1593,7 @@ def __init__( # (epoch, channel, [taper], freq, [segment]) _check_option("data.ndim", data.ndim, (3, 4)) # only allow one extra dimension - if "epoch" not in dimnames or list(dimnames).index("epoch") != 0: + if list(dimnames).index("epoch") != 0: raise ValueError("'epoch' must be the first dimension of `data`.") if events is not None and data.shape[0] != events.shape[0]: raise ValueError( @@ -1604,7 +1601,7 @@ def __init__( f"dimension of `events` ({events.shape[0]})." ) - _check_data_shape(data, info, freqs, dimnames, weights) + _check_data_shape(data, info, freqs, dimnames, weights, is_epoched=True) self.__setstate__( dict( diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 1303c3e88c7..d9e1917ce61 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -488,10 +488,10 @@ def test_spectrum_array_errors(): with pytest.raises(ValueError, match=r"Expected data to have.*dimensions, got.*"): EpochsSpectrumArray(data, info, freqs, dimnames=dimnames[:-1]) # test unrecognised dimnames (for SpectrumArray; epoch not allowed) - with pytest.raises(ValueError, match="`data` must not be epoched"): + with pytest.raises(ValueError, match="Invalid value for the 'dimnames' parameter"): SpectrumArray(data[0, :, :], info, freqs, dimnames=("epoch", "channel")) # test unrecognised dimnames (for EpochsSpectrumArray) - with pytest.raises(ValueError, match=r"entries of `dimnames` must be in.*, got,*"): + with pytest.raises(ValueError, match="Invalid value for the 'dimnames' parameter"): EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "channel", "notfreq")) # test missing dimnames with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): From 7badfbf2f4543a968785e6de1d10c17e3dec2ef9 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 30 Jul 2024 19:54:55 +0200 Subject: [PATCH 12/16] Update (Epochs)SpectrumArray docstrings --- mne/time_frequency/spectrum.py | 62 +++++++++++-------- mne/time_frequency/tests/test_spectrum.py | 73 +++++++++++------------ mne/utils/docs.py | 12 ++-- 3 files changed, 81 insertions(+), 66 deletions(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index e35942b2a88..cc643490915 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -1098,6 +1098,8 @@ class Spectrum(BaseSpectrum): The weights for each taper. Only present if spectra computed with ``method='multitaper'`` and ``output='complex'``. + .. versionadded:: 1.8 + See Also -------- EpochsSpectrum @@ -1214,28 +1216,28 @@ def __getitem__(self, item): return BaseRaw._getitem(self, item, return_times=False) -def _check_data_shape(data, info, freqs, dimnames, weights, is_epoched): - if data.ndim != len(dimnames): +def _check_data_shape(data, info, freqs, dim_names, weights, is_epoched): + if data.ndim != len(dim_names): raise ValueError( - f"Expected data to have {len(dimnames)} dimensions, got {data.ndim}." + f"Expected data to have {len(dim_names)} dimensions, got {data.ndim}." ) allowed_dims = ["epoch", "channel", "freq", "segment", "taper"] if not is_epoched: allowed_dims.remove("epoch") # TODO maybe we should be nice and allow plural versions of each dimname? - for dim in dimnames: - _check_option("dimnames", dim, allowed_dims) - if "channel" not in dimnames or "freq" not in dimnames: - raise ValueError("Both 'channel' and 'freq' must be present in `dimnames`.") + for dim in dim_names: + _check_option("dim_names", dim, allowed_dims) + if "channel" not in dim_names or "freq" not in dim_names: + raise ValueError("Both 'channel' and 'freq' must be present in `dim_names`.") - if list(dimnames).index("channel") != int(is_epoched): + if list(dim_names).index("channel") != int(is_epoched): raise ValueError( f"'channel' must be the {'second' if is_epoched else 'first'} dimension of " "the data." ) want_n_chan = _pick_data_channels(info).size - got_n_chan = data.shape[list(dimnames).index("channel")] + got_n_chan = data.shape[list(dim_names).index("channel")] if got_n_chan != want_n_chan: raise ValueError( f"The number of channels in `data` ({got_n_chan}) must match the number of " @@ -1244,25 +1246,25 @@ def _check_data_shape(data, info, freqs, dimnames, weights, is_epoched): # given we limit max array size and ensure channel & freq dims present, only one of # taper or segment can be present - if "taper" in dimnames: - if dimnames[-2] != "taper": # _psd_from_mt assumes this (called when plotting) + if "taper" in dim_names: + if dim_names[-2] != "taper": # _psd_from_mt assumes this (called when plotting) raise ValueError( "'taper' must be the second to last dimension of the data." ) # expect weights for each taper actual = None if weights is None else weights.size - expected = data.shape[list(dimnames).index("taper")] + expected = data.shape[list(dim_names).index("taper")] if actual != expected: raise ValueError( f"Expected size of `weights` to be {expected} to match 'n_tapers' in " f"`data`, got {actual}." ) - elif "segment" in dimnames and dimnames[-1] != "segment": + elif "segment" in dim_names and dim_names[-1] != "segment": raise ValueError("'segment' must be the last dimension of the data.") # freq being in wrong position ruled out by above checks want_n_freq = freqs.size - got_n_freq = data.shape[list(dimnames).index("freq")] + got_n_freq = data.shape[list(dim_names).index("freq")] if got_n_freq != want_n_freq: raise ValueError( f"The number of frequencies in `data` ({got_n_freq}) must match the number " @@ -1280,14 +1282,18 @@ class SpectrumArray(Spectrum): The spectra for each channel. %(info_not_none)s %(freqs_tfr_array)s - dimnames : tuple of str + dim_names : tuple of str The name of the dimensions in the data, in the order they occur. Must contain ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., multitaper algorithms) dimension. If including ``'taper'``, you should also pass a ``weights`` parameter. + + .. versionadded:: 1.8 weights : ndarray | None - Weights for the ``'taper'`` dimension, if present (see ``dimnames``). + Weights for the ``'taper'`` dimension, if present (see ``dim_names``). + + .. versionadded:: 1.8 %(verbose)s See Also @@ -1310,7 +1316,7 @@ def __init__( data, info, freqs, - dimnames=("channel", "freq"), + dim_names=("channel", "freq"), weights=None, *, verbose=None, @@ -1318,14 +1324,14 @@ def __init__( # (channel, [taper], freq, [segment]) _check_option("data.ndim", data.ndim, (2, 3)) # only allow one extra dimension - _check_data_shape(data, info, freqs, dimnames, weights, is_epoched=False) + _check_data_shape(data, info, freqs, dim_names, weights, is_epoched=False) self.__setstate__( dict( method="unknown", data=data, sfreq=info["sfreq"], - dims=dimnames, + dims=dim_names, freqs=freqs, inst_type_str="Array", data_type=( @@ -1376,6 +1382,8 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): The weights for each taper. Only present if spectra computed with ``method='multitaper'`` and ``output='complex'``. + .. versionadded:: 1.8 + See Also -------- EpochsSpectrumArray @@ -1554,14 +1562,18 @@ class EpochsSpectrumArray(EpochsSpectrum): %(freqs_tfr_array)s %(events_epochs)s %(event_id)s - dimnames : tuple of str + dim_names : tuple of str The name of the dimensions in the data, in the order they occur. Must contain ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., multitaper algorithms) dimension. If including ``'taper'``, you should also pass a ``weights`` parameter. + + .. versionadded:: 1.8 weights : ndarray | None - Weights for the ``'taper'`` dimension, if present (see ``dimnames``). + Weights for the ``'taper'`` dimension, if present (see ``dim_names``). + + .. versionadded:: 1.8 %(verbose)s See Also @@ -1585,7 +1597,7 @@ def __init__( freqs, events=None, event_id=None, - dimnames=("epoch", "channel", "freq"), + dim_names=("epoch", "channel", "freq"), weights=None, *, verbose=None, @@ -1593,7 +1605,7 @@ def __init__( # (epoch, channel, [taper], freq, [segment]) _check_option("data.ndim", data.ndim, (3, 4)) # only allow one extra dimension - if list(dimnames).index("epoch") != 0: + if list(dim_names).index("epoch") != 0: raise ValueError("'epoch' must be the first dimension of `data`.") if events is not None and data.shape[0] != events.shape[0]: raise ValueError( @@ -1601,14 +1613,14 @@ def __init__( f"dimension of `events` ({events.shape[0]})." ) - _check_data_shape(data, info, freqs, dimnames, weights, is_epoched=True) + _check_data_shape(data, info, freqs, dim_names, weights, is_epoched=True) self.__setstate__( dict( method="unknown", data=data, sfreq=info["sfreq"], - dims=dimnames, + dims=dim_names, freqs=freqs, inst_type_str="Array", data_type=( diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index d9e1917ce61..980df42d791 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -439,15 +439,6 @@ def _check_spectrum_equivalent(spect1, spect2, tmp_path): assert_array_equal(spect1.freqs, spect2.freqs) -def _get_dimnames(kind, method, output, average): - dimnames = ("epoch", "channel") if kind == "epochs" else ("channel",) - if method == "welch": - dimnames += ("freq",) if average else ("freq", "segment") - else: # i.e. multitaper - dimnames += ("freq",) if output == "power" else ("taper", "freq") - return dimnames - - def test_spectrum_array_errors(): """Test (Epochs)SpectrumArray constructor errors.""" n_epochs = 10 @@ -457,23 +448,23 @@ def test_spectrum_array_errors(): sfreq = 100 rng = np.random.default_rng(44) data = rng.random((n_epochs, n_chans, n_freqs)) - dimnames = ("epoch", "channel", "freq") + dim_names = ("epoch", "channel", "freq") info = create_info(n_chans, sfreq, "eeg") # test incorrect ndims (for SpectrumArray; allows 2-3D data) with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): - SpectrumArray(data[0, 0, :], info, freqs, dimnames=dimnames) + SpectrumArray(data[0, 0, :], info, freqs, dim_names=dim_names) with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): - SpectrumArray(np.expand_dims(data, axis=3), info, freqs, dimnames=dimnames) + SpectrumArray(np.expand_dims(data, axis=3), info, freqs, dim_names=dim_names) # test incorrect ndims (for EpochsSpectrumArray; allows 3-4D data) with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): - EpochsSpectrumArray(data[0, :, :], info, freqs, dimnames=dimnames) + EpochsSpectrumArray(data[0, :, :], info, freqs, dim_names=dim_names) with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): EpochsSpectrumArray( - np.expand_dims(data, axis=(3, 4)), info, freqs, dimnames=dimnames + np.expand_dims(data, axis=(3, 4)), info, freqs, dim_names=dim_names ) # test incorrect epochs location with pytest.raises(ValueError, match="'epoch' must be the first dimension"): - EpochsSpectrumArray(data, info, freqs, dimnames=("channel", "epoch", "freq")) + EpochsSpectrumArray(data, info, freqs, dim_names=("channel", "epoch", "freq")) # test mismatching events shape events = np.vstack( ( @@ -483,36 +474,40 @@ def test_spectrum_array_errors(): ) ).T with pytest.raises(ValueError, match=r"first dimension.*dimension of `events`"): - EpochsSpectrumArray(data, info, freqs, events, dimnames=dimnames) + EpochsSpectrumArray(data, info, freqs, events, dim_names=dim_names) # test data-dimname mismatch with pytest.raises(ValueError, match=r"Expected data to have.*dimensions, got.*"): - EpochsSpectrumArray(data, info, freqs, dimnames=dimnames[:-1]) - # test unrecognised dimnames (for SpectrumArray; epoch not allowed) - with pytest.raises(ValueError, match="Invalid value for the 'dimnames' parameter"): - SpectrumArray(data[0, :, :], info, freqs, dimnames=("epoch", "channel")) - # test unrecognised dimnames (for EpochsSpectrumArray) - with pytest.raises(ValueError, match="Invalid value for the 'dimnames' parameter"): - EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "channel", "notfreq")) - # test missing dimnames + EpochsSpectrumArray(data, info, freqs, dim_names=dim_names[:-1]) + # test unrecognised dim_names (for SpectrumArray; epoch not allowed) + with pytest.raises(ValueError, match="Invalid value for the 'dim_names' parameter"): + SpectrumArray(data[0, :, :], info, freqs, dim_names=("epoch", "channel")) + # test unrecognised dim_names (for EpochsSpectrumArray) + with pytest.raises(ValueError, match="Invalid value for the 'dim_names' parameter"): + EpochsSpectrumArray( + data, info, freqs, dim_names=("epoch", "channel", "notfreq") + ) + # test missing dim_names with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): - EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "channel", "channel")) + EpochsSpectrumArray( + data, info, freqs, dim_names=("epoch", "channel", "channel") + ) with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): - EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "freq", "freq")) + EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "freq", "freq")) with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): - EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "epoch", "epoch")) + EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "epoch", "epoch")) # test incorrect channel location (for SpectrumArray; must be 1st dim) with pytest.raises(ValueError, match="'channel' must be the first dimension"): - SpectrumArray(data[0, :, :], info, freqs, dimnames=("freq", "channel")) + SpectrumArray(data[0, :, :], info, freqs, dim_names=("freq", "channel")) # test incorrect channel location (for EpochsSpectrumArray; must be 2nd dim) with pytest.raises(ValueError, match="'channel' must be the second dimension"): - EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "freq", "channel")) + EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "freq", "channel")) # test mismatching number of channels with pytest.raises(ValueError, match=r"number of channels.*good data channels"): - EpochsSpectrumArray(data[:, :-1, :], info, freqs, dimnames=dimnames) + EpochsSpectrumArray(data[:, :-1, :], info, freqs, dim_names=dim_names) # test incorrect taper position with pytest.raises(ValueError, match="'taper' must be the second to last dim"): EpochsSpectrumArray( - np.expand_dims(data, axis=3), info, freqs, dimnames=dimnames + ("taper",) + np.expand_dims(data, axis=3), info, freqs, dim_names=dim_names + ("taper",) ) # test incorrect weight size with pytest.raises(ValueError, match=r"Expected size of `weights` to be.*, got.*"): @@ -520,7 +515,7 @@ def test_spectrum_array_errors(): np.expand_dims(data, axis=2), info, freqs, - dimnames=("epoch", "channel", "taper", "freq"), + dim_names=("epoch", "channel", "taper", "freq"), weights=None, ) with pytest.raises(ValueError, match=r"Expected size of `weights` to be.*, got.*"): @@ -528,7 +523,7 @@ def test_spectrum_array_errors(): np.expand_dims(data, axis=2), info, freqs, - dimnames=("epoch", "channel", "taper", "freq"), + dim_names=("epoch", "channel", "taper", "freq"), weights=np.ones((1, 2, 1)), ) # test incorrect segment position @@ -537,11 +532,11 @@ def test_spectrum_array_errors(): np.expand_dims(data, axis=2), info, freqs, - dimnames=("epoch", "channel", "segment", "freq"), + dim_names=("epoch", "channel", "segment", "freq"), ) # test mismatching number of frequencies with pytest.raises(ValueError, match=r"number of frequencies.*number of elements"): - EpochsSpectrumArray(data[:, :, :-1], info, freqs, dimnames=dimnames) + EpochsSpectrumArray(data[:, :, :-1], info, freqs, dim_names=dim_names) @pytest.mark.parametrize( @@ -554,7 +549,11 @@ def test_spectrum_array_errors(): ) def test_spectrum_array(kind, method, output, average, tmp_path, request): """Test EpochsSpectrumArray and SpectrumArray constructors.""" - dimnames = _get_dimnames(kind, method, output, average) + dim_names = ("epoch", "channel") if kind == "epochs" else ("channel",) + if method == "welch": + dim_names += ("freq",) if average else ("freq", "segment") + else: # i.e. multitaper + dim_names += ("freq",) if output == "power" else ("taper", "freq") if method == "welch" and output == "power" and average: spectrum = request.getfixturevalue(f"{kind}_spectrum") else: @@ -569,7 +568,7 @@ def test_spectrum_array(kind, method, output, average, tmp_path, request): data=data, info=spectrum.info, freqs=freqs, - dimnames=dimnames, + dim_names=dim_names, weights=spectrum.weights, ) _check_spectrum_equivalent(spectrum, spect_arr, tmp_path) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index ff9e11ee776..57a0999fd1e 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2922,11 +2922,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["notes_plot_psd_meth"] = _notes_plot_psd.format("method") docdict["notes_spectrum_array"] = """ -It is assumed that the data passed in represent spectral *power* (not amplitude, -phase, model coefficients, etc) and downstream methods (such as +If the data passed in is real-valued, it is assumed to represent spectral *power* (not +amplitude, phase, etc), and downstream methods (such as :meth:`~mne.time_frequency.SpectrumArray.plot`) assume power data. If you pass in -something other than power, at the very least axis labels will be inaccurate (and -other things may also not work or be incorrect). +real-valued data that is not power, axis labels will be incorrect. + +If the data passed in is complex-valued, it is assumed to represent Fourier +coefficients. Downstream plotting methods will treat the data as such, attempting to +convert this to power before visualisation. If you pass in complex-valued data that is +not Fourier coefficients, axis labels will be incorrect. """ docdict["notes_timefreqs_tfr_plot_joint"] = """ From 23ae24935aa808cd9be2df74c117f47692614eaa Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 31 Jul 2024 18:21:11 +0200 Subject: [PATCH 13/16] Fix literals formatting --- mne/time_frequency/spectrum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index cc643490915..902d1c70e30 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -1089,7 +1089,7 @@ class Spectrum(BaseSpectrum): Frequencies at which the amplitude, power, or fourier coefficients have been computed. %(info_not_none)s - method : "welch" | "multitaper" + method : ``'welch'``| ``'multitaper'`` The method used to compute the spectrum. nave : int | None The number of trials averaged together when generating the spectrum. ``None`` @@ -1376,7 +1376,7 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): Frequencies at which the amplitude, power, or fourier coefficients have been computed. %(info_not_none)s - method : "welch" | "multitaper" + method : ``'welch'``| ``'multitaper'`` The method used to compute the spectrum. weights : array | None The weights for each taper. Only present if spectra computed with From 172d36d4f337d478053cfca6f27b33b4374ad3e8 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 31 Jul 2024 18:35:30 +0200 Subject: [PATCH 14/16] Remove changelog entry --- doc/changes/devel/12747.newfeature.rst | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 doc/changes/devel/12747.newfeature.rst diff --git a/doc/changes/devel/12747.newfeature.rst b/doc/changes/devel/12747.newfeature.rst deleted file mode 100644 index 2957117b778..00000000000 --- a/doc/changes/devel/12747.newfeature.rst +++ /dev/null @@ -1,3 +0,0 @@ -Add support for storing Fourier coefficients in :class:`mne.time_frequency.Spectrum`, -:class:`mne.time_frequency.EpochsSpectrum`, :class:`mne.time_frequency.SpectrumArray`, -and :class:`mne.time_frequency.EpochsSpectrumArray` objects, by `Thomas Binns`_. \ No newline at end of file From d81c6c4e0da55992f5c8bdb31a00e7fefa22ccf6 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 31 Jul 2024 18:36:42 +0200 Subject: [PATCH 15/16] Restore changelog entry --- doc/changes/devel/12747.newfeature.rst | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 doc/changes/devel/12747.newfeature.rst diff --git a/doc/changes/devel/12747.newfeature.rst b/doc/changes/devel/12747.newfeature.rst new file mode 100644 index 00000000000..2957117b778 --- /dev/null +++ b/doc/changes/devel/12747.newfeature.rst @@ -0,0 +1,3 @@ +Add support for storing Fourier coefficients in :class:`mne.time_frequency.Spectrum`, +:class:`mne.time_frequency.EpochsSpectrum`, :class:`mne.time_frequency.SpectrumArray`, +and :class:`mne.time_frequency.EpochsSpectrumArray` objects, by `Thomas Binns`_. \ No newline at end of file From c823201ffc641191a65fca789c0165e52cad4b75 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 31 Jul 2024 13:08:03 -0400 Subject: [PATCH 16/16] FIX: Workaround --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b9033adb8d1..c0fabebc5c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,6 +172,8 @@ doc = [ "ipython!=8.7.0", "selenium", "intersphinx_registry>=0.2405.27", + # https://github.com/sphinx-contrib/sphinxcontrib-towncrier/issues/92 + "towncrier<24.7", ] dev = ["mne[test,doc]", "rcssmin"]