From e221c6cb71ec42f7eebe1c5ea81089d78c91d4a8 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 2 Sep 2024 16:48:22 -0400 Subject: [PATCH 01/55] MAINT: Fix for dipy deprecation (#12823) --- .git-blame-ignore-revs | 1 + examples/inverse/morph_volume_stc.py | 1 + mne/morph.py | 8 ++++++- mne/preprocessing/ieeg/_volume.py | 4 +++- mne/tests/test_morph.py | 6 +++++- mne/transforms.py | 32 +++++++++++++++++++--------- 6 files changed, 39 insertions(+), 13 deletions(-) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index a59ba48ab01..054b0c65924 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -13,3 +13,4 @@ e39995d9be6fc831c7a4a59f09b7a7c0a41ae315 # 12588, percent formatting 1c5b39ff1d99bbcb2fc0e0071a989b3f3845ff30 # 12603, ruff UP028 b8b168088cb474f27833f5f9db9d60abe00dca83 # 12779, PR JSONs ee64eba6f345e895e3d5e7d2804fa6aa2dac2e6d # 12781, Header unification +362f9330925fb79a6adc19a42243672676dec63e # 12799, UP038 diff --git a/examples/inverse/morph_volume_stc.py b/examples/inverse/morph_volume_stc.py index 7a9db303e90..24b23fc374e 100644 --- a/examples/inverse/morph_volume_stc.py +++ b/examples/inverse/morph_volume_stc.py @@ -20,6 +20,7 @@ result will be plotted, showing the fsaverage T1 weighted anatomical MRI, overlaid with the morphed volumetric source estimate. """ + # Author: Tommy Clausner # # License: BSD-3-Clause diff --git a/mne/morph.py b/mne/morph.py index 69ea3c3fab2..9c475bff1e9 100644 --- a/mne/morph.py +++ b/mne/morph.py @@ -1160,7 +1160,13 @@ def _compute_morph_sdr(mri_from, mri_to, niter_affine, niter_sdr, zooms): ) = _compute_volume_registration( mri_from, mri_to, zooms=zooms, niter=niter, pipeline=pipeline ) - pre_affine = AffineMap(pre_affine, to_shape, to_affine, from_shape, from_affine) + pre_affine = AffineMap( + pre_affine, + domain_grid_shape=to_shape, + domain_grid2world=to_affine, + codomain_grid_shape=from_shape, + codomain_grid2world=from_affine, + ) return to_shape, zooms, to_affine, pre_affine, sdr_morph diff --git a/mne/preprocessing/ieeg/_volume.py b/mne/preprocessing/ieeg/_volume.py index 8289e1defcd..b4997b2e3f8 100644 --- a/mne/preprocessing/ieeg/_volume.py +++ b/mne/preprocessing/ieeg/_volume.py @@ -83,7 +83,9 @@ def warp_montage(montage, moving, static, reg_affine, sdr_morph, verbose=None): # now, apply SDR morph if sdr_morph is not None: ch_coords = sdr_morph.transform_points( - ch_coords, sdr_morph.domain_grid2world, sdr_morph.domain_world2grid + ch_coords, + coord2world=sdr_morph.domain_grid2world, + world2coord=sdr_morph.domain_world2grid, ) # back to voxels but now for the static image diff --git a/mne/tests/test_morph.py b/mne/tests/test_morph.py index 1762049eb9f..7cb7d0cb9d9 100644 --- a/mne/tests/test_morph.py +++ b/mne/tests/test_morph.py @@ -1137,7 +1137,11 @@ def test_resample_equiv(from_shape, from_affine, to_shape, to_affine, order, see interp = "linear" if order == 1 else "nearest" got_dipy = dipy.align.imaffine.AffineMap( - None, to_shape, to_affine, from_shape, from_affine + None, + domain_grid_shape=to_shape, + domain_grid2world=to_affine, + codomain_grid_shape=from_shape, + codomain_grid2world=from_affine, ).transform(from_data, interpolation=interp, resample_only=True) # XXX possibly some error in dipy or nibabel (/SciPy), or some boundary # condition? diff --git a/mne/transforms.py b/mne/transforms.py index 6b819f58753..024ff0073a4 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -1818,10 +1818,10 @@ def _compute_volume_registration( sigma_diff_vox = sigma_diff_mm / current_zoom affine_map = AffineMap( reg_affine, # apply registration here - static_zoomed.shape, - static_affine, - moving_zoomed.shape, - moving_affine, + domain_grid_shape=static_zoomed.shape, + domain_grid2world=static_affine, + codomain_grid_shape=moving_zoomed.shape, + codomain_grid2world=moving_affine, ) moving_zoomed = affine_map.transform(moving_zoomed) metric = metrics.CCMetric( @@ -1829,10 +1829,16 @@ def _compute_volume_registration( sigma_diff=sigma_diff_vox, radius=max(int(np.ceil(2 * sigma_diff_vox)), 1), ) - sdr = imwarp.SymmetricDiffeomorphicRegistration(metric, niter[step]) + sdr = imwarp.SymmetricDiffeomorphicRegistration( + metric, + level_iters=niter[step], + ) with wrapped_stdout(indent=" ", cull_newlines=True): sdr_morph = sdr.optimize( - static_zoomed, moving_zoomed, static_affine, static_affine + static_zoomed, + moving_zoomed, + static_grid2world=static_affine, + moving_grid2world=static_affine, ) moved_zoomed = sdr_morph.transform(moving_zoomed) else: @@ -1841,8 +1847,8 @@ def _compute_volume_registration( moved_zoomed, reg_affine = affine_registration( moving_zoomed, static_zoomed, - moving_affine, - static_affine, + moving_affine=moving_affine, + static_affine=static_affine, nbins=32, metric="MI", pipeline=pipeline_options[step], @@ -1936,7 +1942,11 @@ def apply_volume_registration( moving -= cval static, static_affine = np.asarray(static.dataobj), static.affine affine_map = AffineMap( - reg_affine, static.shape, static_affine, moving.shape, moving_affine + reg_affine, + domain_grid_shape=static.shape, + domain_grid2world=static_affine, + codomain_grid_shape=moving.shape, + codomain_grid2world=moving_affine, ) reg_data = affine_map.transform(moving, interpolation=interpolation) if sdr_morph is not None: @@ -2029,7 +2039,9 @@ def apply_volume_registration_points( if sdr_morph is not None: _require_version("dipy", "SDR morph", "1.6.0") locs = sdr_morph.transform_points( - locs, sdr_morph.domain_grid2world, sdr_morph.domain_world2grid + locs, + coord2world=sdr_morph.domain_grid2world, + world2coord=sdr_morph.domain_world2grid, ) locs = apply_trans( Transform( # to static voxels From 3f1c7803f73862dec3e4e85e36e916f2a6d642c9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 23:22:42 +0000 Subject: [PATCH 02/55] [pre-commit.ci] pre-commit autoupdate (#12825) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b17205310b8..bdbd926ebc7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.2 + rev: v0.6.3 hooks: - id: ruff name: ruff lint mne From 013d6c9f630342e5af7633641162c7b85417782a Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 4 Sep 2024 20:05:32 +0200 Subject: [PATCH 03/55] Fix docstring typos (#12826) --- mne/preprocessing/hfc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/preprocessing/hfc.py b/mne/preprocessing/hfc.py index 02bfba1eae6..f8a65510a9a 100644 --- a/mne/preprocessing/hfc.py +++ b/mne/preprocessing/hfc.py @@ -16,7 +16,7 @@ def compute_proj_hfc( ): """Generate projectors to perform homogeneous/harmonic correction to data. - Remove evironmental fields from magentometer data by assuming it is + Remove environmental fields from magnetometer data by assuming it is explained as a homogeneous :footcite:`TierneyEtAl2021` or harmonic field :footcite:`TierneyEtAl2022`. Useful for arrays of OPMs. @@ -26,7 +26,7 @@ def compute_proj_hfc( order : int The order of the spherical harmonic basis set to use. Set to 1 to use only the homogeneous field component (default), 2 to add gradients, 3 - to add quadrature terms etc. + to add quadrature terms, etc. picks : str | array_like | slice | None Channels to include. Default of ``'meg'`` (same as None) will select all non-reference MEG channels. Use ``('meg', 'ref_meg')`` to include From 3ecdd6b2102b00f97d7ad98ca2483a0f79e90b30 Mon Sep 17 00:00:00 2001 From: Stefan Appelhoff Date: Fri, 6 Sep 2024 15:58:05 +0200 Subject: [PATCH 04/55] Add update to `.elc` montage reader (#12830) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/changes/devel/12830.newfeature.rst | 1 + mne/channels/_standard_montage_utils.py | 23 +++++++++++++++++++---- mne/channels/tests/test_montage.py | 25 ++++++++++++++++++++++++- 3 files changed, 44 insertions(+), 5 deletions(-) create mode 100644 doc/changes/devel/12830.newfeature.rst diff --git a/doc/changes/devel/12830.newfeature.rst b/doc/changes/devel/12830.newfeature.rst new file mode 100644 index 00000000000..4d51229392d --- /dev/null +++ b/doc/changes/devel/12830.newfeature.rst @@ -0,0 +1 @@ +:func:`mne.channels.read_custom_montage` may now read a newer version of the ``.elc`` ASA Electrode file format, by `Stefan Appelhoff`_. diff --git a/mne/channels/_standard_montage_utils.py b/mne/channels/_standard_montage_utils.py index 26d9c1434fe..ca4066e9a15 100644 --- a/mne/channels/_standard_montage_utils.py +++ b/mne/channels/_standard_montage_utils.py @@ -230,6 +230,11 @@ def _check_dupes_odict(ch_names, pos): def _read_elc(fname, head_size): """Read .elc files. + The `.elc` files are so-called "asa electrode files". ASA here stands for + Advances Source Analysis, and is a software package developed and sold by + the ANT Neuro company. They provide a device for sensor digitization, called + 'xensor', which produces the `.elc` files. + Parameters ---------- fname : str @@ -241,12 +246,12 @@ def _read_elc(fname, head_size): Returns ------- montage : instance of DigMontage - The montage in [m]. + The montage units are [m]. """ fid_names = ("Nz", "LPA", "RPA") - ch_names_, pos = [], [] with open(fname) as fid: + # Read units # _read_elc does require to detect the units. (see _mgh_or_standard) for line in fid: if "UnitPosition" in line: @@ -258,15 +263,25 @@ def _read_elc(fname, head_size): for line in fid: if "Positions\n" in line: break + + # Read positions pos = [] for line in fid: if "Labels\n" in line: break - pos.append(list(map(float, line.split()))) + if ":" in line: + # Of the 'new' format: `E01 : 5.288 -3.658 119.693` + pos.append(list(map(float, line.split(":")[1].split()))) + else: + # Of the 'old' format: `5.288 -3.658 119.693` + pos.append(list(map(float, line.split()))) + + # Read labels + ch_names_ = [] for line in fid: if not line or not set(line) - {" "}: break - ch_names_.append(line.strip(" ").strip("\n")) + ch_names_.extend(line.strip(" ").strip("\n").split()) pos = np.array(pos) * scale if head_size is not None: diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index 92a489adc35..de251fb2872 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -300,7 +300,30 @@ def test_documented(): ), "elc", None, - id="ASA electrode", + id="old ASA electrode (elc)", + ), + pytest.param( + partial(read_custom_montage, head_size=None), + ( + "NumberPositions= 96\n" + "UnitPosition mm\n" + "Positions\n" + "E01 : 5.288 -3.658 119.693\n" + "E02 : 59.518 -4.031 101.404\n" + "E03 : 29.949 -50.988 98.145\n" + "Labels\n" + "E01 E02 E03\n" + ), + make_dig_montage( + ch_pos={ + "E01": [0.005288, -0.003658, 0.119693], + "E02": [0.059518, -0.004031, 0.101404], + "E03": [0.029949, -0.050988, 0.098145], + }, + ), + "elc", + None, + id="new ASA electrode (elc)", ), pytest.param( partial(read_custom_montage, head_size=1), From c46733a4812f5d19370699c7b18cb1c1c4ded9fa Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 6 Sep 2024 12:11:20 -0400 Subject: [PATCH 05/55] MAINT: No sklearn nightly for now (#12832) --- tools/install_pre_requirements.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/install_pre_requirements.sh b/tools/install_pre_requirements.sh index 88c21f52225..85c797be7aa 100755 --- a/tools/install_pre_requirements.sh +++ b/tools/install_pre_requirements.sh @@ -25,11 +25,13 @@ if [[ "${PLATFORM}" == "Linux" ]]; then fi python -m pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 \ --index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \ - "numpy>=2.1.0.dev0" "scikit-learn>=1.6.dev0" "scipy>=1.15.0.dev0" \ + "numpy>=2.1.0.dev0" "scipy>=1.15.0.dev0" \ "statsmodels>=0.15.0.dev0" "pandas>=3.0.0.dev0" "matplotlib>=3.10.0.dev0" \ $OTHERS # No Numba because it forces an old NumPy version +# No sklearn from SPNW until we figure out https://github.com/scikit-learn/scikit-learn/pull/29677 +pip install $STD_ARGS --upgrade scikit-learn if [[ "${PLATFORM}" == "Linux" ]]; then echo "pymatreader" From c993ae5d65b3c161876da407e293e6784c6e8ad9 Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Fri, 6 Sep 2024 18:20:35 +0200 Subject: [PATCH 06/55] [API] Deprecate average parameter from CSP and SPoC plotting methods (#12829) --- doc/changes/devel/12829.apichange.rst | 1 + mne/decoding/csp.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 doc/changes/devel/12829.apichange.rst diff --git a/doc/changes/devel/12829.apichange.rst b/doc/changes/devel/12829.apichange.rst new file mode 100644 index 00000000000..d0bd4c12a46 --- /dev/null +++ b/doc/changes/devel/12829.apichange.rst @@ -0,0 +1 @@ +Deprecate ``average`` parameter in ``plot_filters`` and ``plot_patterns`` methods of the :class:`mne.decoding.CSP` and :class:`mne.decoding.SPoC` classes, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 7bd77361fa6..b0e5d7b4bc8 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -18,6 +18,7 @@ copy_doc, fill_doc, pinv, + warn, ) from .base import BaseEstimator from .mixin import TransformerMixin @@ -369,6 +370,9 @@ def plot_patterns( if components is None: components = np.arange(self.n_components) + if average is not None: + warn("`average` is deprecated and will be removed in 1.10.", FutureWarning) + # set sampling frequency to have 1 component per time point info = cp.deepcopy(info) with info._unlock(): @@ -500,6 +504,9 @@ def plot_filters( if components is None: components = np.arange(self.n_components) + if average is not None: + warn("`average` is deprecated and will be removed in 1.10.", FutureWarning) + # set sampling frequency to have 1 component per time point info = cp.deepcopy(info) with info._unlock(): From f3a3ca4430e1d4b9c539e7949c946f4f83bdb43f Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Fri, 6 Sep 2024 19:37:20 +0200 Subject: [PATCH 07/55] [DOC] Fix misleading `fit_transform` docstrings (#12827) Co-authored-by: Eric Larson Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/changes/devel/12827.other.rst | 1 + mne/decoding/csp.py | 55 +++++++++++++++++++++++++++++-- mne/decoding/ssd.py | 25 ++++++++++++++ 3 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 doc/changes/devel/12827.other.rst diff --git a/doc/changes/devel/12827.other.rst b/doc/changes/devel/12827.other.rst new file mode 100644 index 00000000000..3ccbaa0bff6 --- /dev/null +++ b/doc/changes/devel/12827.other.rst @@ -0,0 +1 @@ +Improve documentation clarity of ``fit_transform`` methods for :class:`mne.decoding.SSD`, :class:`mne.decoding.CSP`, and :class:`mne.decoding.SPoC` classes, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index b0e5d7b4bc8..6d92b5d17bd 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -15,7 +15,6 @@ _check_option, _validate_type, _verbose_safe_false, - copy_doc, fill_doc, pinv, warn, @@ -273,8 +272,31 @@ def inverse_transform(self, X): ) return X[:, np.newaxis, :] * self.patterns_[: self.n_components].T - @copy_doc(TransformerMixin.fit_transform) - def fit_transform(self, X, y, **fit_params): # noqa: D102 + def fit_transform(self, X, y=None, **fit_params): + """Fit CSP to data, then transform it. + + Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and + returns a transformed version of ``X``. + + Parameters + ---------- + X : array, shape (n_epochs, n_channels, n_times) + The data on which to estimate the CSP. + y : array, shape (n_epochs,) + The class for each epoch. + **fit_params : dict + Additional fitting parameters passed to the :meth:`mne.decoding.CSP.fit` + method. Not used for this class. + + Returns + ------- + X_csp : array, shape (n_epochs, n_components[, n_times]) + If ``self.transform_into == 'average_power'`` then returns the power of CSP + features averaged over time and shape is ``(n_epochs, n_components)``. If + ``self.transform_into == 'csp_space'`` then returns the data in CSP space + and shape is ``(n_epochs, n_components, n_times)``. + """ + # use parent TransformerMixin method but with custom docstring return super().fit_transform(X, y=y, **fit_params) @fill_doc @@ -953,3 +975,30 @@ def transform(self, X): space and shape is (n_epochs, n_components, n_times). """ return super().transform(X) + + def fit_transform(self, X, y=None, **fit_params): + """Fit SPoC to data, then transform it. + + Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and + returns a transformed version of ``X``. + + Parameters + ---------- + X : array, shape (n_epochs, n_channels, n_times) + The data on which to estimate the SPoC. + y : array, shape (n_epochs,) + The class for each epoch. + **fit_params : dict + Additional fitting parameters passed to the :meth:`mne.decoding.CSP.fit` + method. Not used for this class. + + Returns + ------- + X : array, shape (n_epochs, n_components[, n_times]) + If ``self.transform_into == 'average_power'`` then returns the power of CSP + features averaged over time and shape is ``(n_epochs, n_components)``. If + ``self.transform_into == 'csp_space'`` then returns the data in CSP space + and shape is ``(n_epochs, n_components, n_times)``. + """ + # use parent TransformerMixin method but with custom docstring + return super().fit_transform(X, y=y, **fit_params) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index f5f1ff94516..23e3136ce36 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -261,6 +261,31 @@ def transform(self, X): X_ssd = X_ssd[:, self.sorter_spec, :][:, : self.n_components, :] return X_ssd + def fit_transform(self, X, y=None, **fit_params): + """Fit SSD to data, then transform it. + + Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and + returns a transformed version of ``X``. + + Parameters + ---------- + X : array, shape ([n_epochs, ]n_channels, n_times) + The input data from which to estimate the SSD. Either 2D array obtained from + continuous data or 3D array obtained from epoched data. + y : None + Ignored; exists for compatibility with scikit-learn pipelines. + **fit_params : dict + Additional fitting parameters passed to the :meth:`mne.decoding.SSD.fit` + method. Not used for this class. + + Returns + ------- + X_ssd : array, shape ([n_epochs, ]n_components, n_times) + The processed data. + """ + # use parent TransformerMixin method but with custom docstring + return super().fit_transform(X, y=y, **fit_params) + def get_spectral_ratio(self, ssd_sources): """Get the spectal signal-to-noise ratio for each spatial filter. From 5cd52993c0027a958cfbc4f6810dff5d180301e5 Mon Sep 17 00:00:00 2001 From: Clemens Brunner Date: Wed, 11 Sep 2024 17:47:02 +0200 Subject: [PATCH 08/55] Disable unicode symbols in `mne.sys_info()` on Windows (#12838) --- mne/utils/config.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mne/utils/config.py b/mne/utils/config.py index 3e811323d3e..29e890e6aa1 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -610,7 +610,7 @@ def sys_info( show_paths=False, *, dependencies="user", - unicode=True, + unicode="auto", check_version=True, ): """Print system information. @@ -627,8 +627,9 @@ def sys_info( dependencies : 'user' | 'developer' Show dependencies relevant for users (default) or for developers (i.e., output includes additional dependencies). - unicode : bool - Include Unicode symbols in output. + unicode : bool | "auto" + Include Unicode symbols in output. If "auto", corresponds to True on Linux and + macOS, and False on Windows. .. versionadded:: 0.24 check_version : bool | float @@ -641,6 +642,13 @@ def sys_info( _validate_type(dependencies, str) _check_option("dependencies", dependencies, ("user", "developer")) _validate_type(check_version, (bool, "numeric"), "check_version") + _validate_type(unicode, (bool, str), "unicode") + _check_option("unicode", unicode, ("auto", True, False)) + if unicode == "auto": + if platform.system() in ("Darwin", "Linux"): + unicode = True + else: # Windows + unicode = False ljust = 24 if dependencies == "developer" else 21 platform_str = platform.platform() From 828953e0ba584c149bcd2d6898d3599af1b30f40 Mon Sep 17 00:00:00 2001 From: Clemens Brunner Date: Wed, 11 Sep 2024 17:50:36 +0200 Subject: [PATCH 09/55] Determine total memory without psutil (#12787) --- doc/changes/devel/12787.other.rst | 1 + mne/utils/config.py | 60 ++++++++++++++++++++++++++++--- mne/utils/tests/test_config.py | 2 +- pyproject.toml | 2 +- 4 files changed, 58 insertions(+), 7 deletions(-) create mode 100644 doc/changes/devel/12787.other.rst diff --git a/doc/changes/devel/12787.other.rst b/doc/changes/devel/12787.other.rst new file mode 100644 index 00000000000..1f53fdea066 --- /dev/null +++ b/doc/changes/devel/12787.other.rst @@ -0,0 +1 @@ +Use custom code in :func:`mne.sys_info` to get the amount of physical memory and a more informative CPU name instead of using the ``psutil`` package, by `Clemens Brunner`_. \ No newline at end of file diff --git a/mne/utils/config.py b/mne/utils/config.py index 29e890e6aa1..8c0b827c9a6 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -30,6 +30,10 @@ _temp_home_dir = None +class UnknownPlatformError(Exception): + """Exception raised for unknown platforms.""" + + def set_cache_dir(cache_dir): """Set the directory to be used for temporary file storage. @@ -605,6 +609,47 @@ def _get_gpu_info(): return out +def _get_total_memory(): + """Return the total memory of the system in bytes.""" + if platform.system() == "Windows": + o = subprocess.check_output( + [ + "powershell.exe", + "(Get-CimInstance Win32_ComputerSystem).TotalPhysicalMemory", + ] + ).decode() + total_memory = int(o) + elif platform.system() == "Linux": + o = subprocess.check_output(["free", "-b"]).decode() + total_memory = int(o.splitlines()[1].split()[1]) + elif platform.system() == "Darwin": + o = subprocess.check_output(["sysctl", "hw.memsize"]).decode() + total_memory = int(o.split(":")[1].strip()) + else: + raise UnknownPlatformError("Could not determine total memory") + + return total_memory + + +def _get_cpu_brand(): + """Return the CPU brand string.""" + if platform.system() == "Windows": + o = subprocess.check_output( + ["powershell.exe", "(Get-CimInstance Win32_Processor).Name"] + ).decode() + cpu_brand = o.strip().splitlines()[-1] + elif platform.system() == "Linux": + o = subprocess.check_output(["grep", "model name", "/proc/cpuinfo"]).decode() + cpu_brand = o.splitlines()[0].split(": ")[1] + elif platform.system() == "Darwin": + o = subprocess.check_output(["sysctl", "machdep.cpu"]).decode() + cpu_brand = o.split("brand_string: ")[1].strip() + else: + cpu_brand = "?" + + return cpu_brand + + def sys_info( fid=None, show_paths=False, @@ -656,15 +701,20 @@ def sys_info( out("Platform".ljust(ljust) + platform_str + "\n") out("Python".ljust(ljust) + str(sys.version).replace("\n", " ") + "\n") out("Executable".ljust(ljust) + sys.executable + "\n") - out("CPU".ljust(ljust) + f"{platform.processor()} ") + try: + cpu_brand = _get_cpu_brand() + except Exception: + cpu_brand = "?" + out("CPU".ljust(ljust) + f"{cpu_brand} ") out(f"({multiprocessing.cpu_count()} cores)\n") out("Memory".ljust(ljust)) try: - import psutil - except ImportError: - out('Unavailable (requires "psutil" package)') + total_memory = _get_total_memory() + except UnknownPlatformError: + total_memory = "?" else: - out(f"{psutil.virtual_memory().total / float(2 ** 30):0.1f} GB\n") + total_memory = f"{total_memory / 1024**3:.1f}" # convert to GiB + out(f"{total_memory} GiB\n") out("\n") ljust -= 3 # account for +/- symbols libs = _get_numpy_libs() diff --git a/mne/utils/tests/test_config.py b/mne/utils/tests/test_config.py index 5636801d5ae..e7611081b55 100644 --- a/mne/utils/tests/test_config.py +++ b/mne/utils/tests/test_config.py @@ -110,7 +110,7 @@ def test_sys_info_basic(): assert "numpy" in out # replace all in-line whitespace with single space out = "\n".join(" ".join(o.split()) for o in out.splitlines()) - + assert "? GiB" not in out if platform.system() == "Darwin": assert "Platform macOS-" in out elif platform.system() == "Linux": diff --git a/pyproject.toml b/pyproject.toml index fbabaed1888..55f02774163 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,6 @@ full-no-qt = [ "jupyter", "python-picard", "joblib", - "psutil", "dipy", "vtk", "nilearn", @@ -175,6 +174,7 @@ doc = [ "intersphinx_registry>=0.2405.27", # https://github.com/sphinx-contrib/sphinxcontrib-towncrier/issues/92 "towncrier<24.7", + "psutil", ] dev = ["mne[test,doc]", "rcssmin"] From ebab915ea723155ee0fed788bcc92c2ff16a4f7f Mon Sep 17 00:00:00 2001 From: Stefan Appelhoff Date: Wed, 11 Sep 2024 20:39:41 +0200 Subject: [PATCH 10/55] hotfix elc reader (#12839) Co-authored-by: Eric Larson --- mne/channels/_standard_montage_utils.py | 10 ++++- mne/channels/montage.py | 6 ++- mne/channels/tests/test_montage.py | 56 ++++++++++++++++++++++++- tutorials/inverse/70_eeg_mri_coords.py | 4 +- 4 files changed, 71 insertions(+), 5 deletions(-) diff --git a/mne/channels/_standard_montage_utils.py b/mne/channels/_standard_montage_utils.py index ca4066e9a15..eb3dc10d10e 100644 --- a/mne/channels/_standard_montage_utils.py +++ b/mne/channels/_standard_montage_utils.py @@ -265,6 +265,7 @@ def _read_elc(fname, head_size): break # Read positions + new_style = False pos = [] for line in fid: if "Labels\n" in line: @@ -272,6 +273,7 @@ def _read_elc(fname, head_size): if ":" in line: # Of the 'new' format: `E01 : 5.288 -3.658 119.693` pos.append(list(map(float, line.split(":")[1].split()))) + new_style = True else: # Of the 'old' format: `5.288 -3.658 119.693` pos.append(list(map(float, line.split()))) @@ -281,7 +283,13 @@ def _read_elc(fname, head_size): for line in fid: if not line or not set(line) - {" "}: break - ch_names_.extend(line.strip(" ").strip("\n").split()) + if new_style: + # Not sure how this format would deal with spaces in channel labels, + # but none of my test files had this, so let's wait until it comes up. + parsed = line.strip(" ").strip("\n").split() + else: + parsed = [line.strip(" ").strip("\n")] + ch_names_.extend(parsed) pos = np.array(pos) * scale if head_size is not None: diff --git a/mne/channels/montage.py b/mne/channels/montage.py index ddf885452a1..bb2f2006ee4 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -1537,7 +1537,10 @@ def _read_eeglab_locations(fname): return ch_names, pos -def read_custom_montage(fname, head_size=HEAD_SIZE_DEFAULT, coord_frame=None): +@verbose +def read_custom_montage( + fname, head_size=HEAD_SIZE_DEFAULT, coord_frame=None, *, verbose=None +): """Read a montage from a file. Parameters @@ -1558,6 +1561,7 @@ def read_custom_montage(fname, head_size=HEAD_SIZE_DEFAULT, coord_frame=None): for most readers but ``"head"`` for EEGLAB. .. versionadded:: 0.20 + %(verbose)s Returns ------- diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index de251fb2872..d0c406473e8 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -325,6 +325,42 @@ def test_documented(): None, id="new ASA electrode (elc)", ), + pytest.param( + partial(read_custom_montage, head_size=None), + ( + "ReferenceLabel\n" + "avg\n" + "UnitPosition mm\n" + "NumberPositions= 6\n" + "Positions\n" + "-69.2574 10.5895 -25.0009\n" + "3.3791 94.6594 32.2592\n" + "77.2856 12.0537 -30.2488\n" + "4.6147 121.8858 8.6370\n" + "-31.3669 54.0269 94.9191\n" + "-8.7495 56.5653 99.6655\n" + "Labels\n" + "LPA\n" + "Nz\n" + "RPA\n" + "EEG 000\n" + "EEG 001\n" + "EEG 002\n" + ), + make_dig_montage( + ch_pos={ + "EEG 000": [0.004615, 0.121886, 0.008637], + "EEG 001": [-0.031367, 0.054027, 0.094919], + "EEG 002": [-0.00875, 0.056565, 0.099665], + }, + nasion=[0.003379, 0.094659, 0.032259], + lpa=[-0.069257, 0.010589, -0.025001], + rpa=[0.077286, 0.012054, -0.030249], + ), + "elc", + None, + id="another old ASA electrode (elc)", + ), pytest.param( partial(read_custom_montage, head_size=1), ( @@ -545,8 +581,26 @@ def test_montage_readers(reader, file_content, expected_dig, ext, warning, tmp_p actual_ch_pos = dig_montage._get_ch_pos() expected_ch_pos = expected_dig._get_ch_pos() for kk in actual_ch_pos: - assert_allclose(actual_ch_pos[kk], expected_ch_pos[kk], atol=1e-5) + assert_allclose(actual_ch_pos[kk], expected_ch_pos[kk], atol=1e-5, err_msg=kk) assert len(dig_montage.dig) == len(expected_dig.dig) + for key in ("nasion", "lpa", "rpa"): + expected = [ + d + for d in expected_dig.dig + if d["kind"] == FIFF.FIFFV_POINT_CARDINAL + and d["ident"] == getattr(FIFF, f"FIFFV_POINT_{key.upper()}") + ] + got = [ + d + for d in dig_montage.dig + if d["kind"] == FIFF.FIFFV_POINT_CARDINAL + and d["ident"] == getattr(FIFF, f"FIFFV_POINT_{key.upper()}") + ] + assert len(expected) in (0, 1), key + assert len(got) in (0, 1), key + assert len(expected) == len(got) + if len(expected): + assert_allclose(got[0]["r"], expected[0]["r"], atol=1e-5, err_msg=key) for d1, d2 in zip(dig_montage.dig, expected_dig.dig): assert d1["coord_frame"] == d2["coord_frame"] for key in ("coord_frame", "ident", "kind"): diff --git a/tutorials/inverse/70_eeg_mri_coords.py b/tutorials/inverse/70_eeg_mri_coords.py index 9783435f26c..5feeca0d2bf 100644 --- a/tutorials/inverse/70_eeg_mri_coords.py +++ b/tutorials/inverse/70_eeg_mri_coords.py @@ -5,8 +5,8 @@ EEG source localization given electrode locations on an MRI =========================================================== -This tutorial explains how to compute the forward operator from EEG data when -the electrodes are in MRI voxel coordinates. +This tutorial explains how to compute the forward operator from EEG data when the +electrodes are in MRI voxel coordinates. """ # Authors: Eric Larson From 9a2f887146f080357a83104c4af2c95e2176fe88 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 10:25:18 -0400 Subject: [PATCH 11/55] [pre-commit.ci] pre-commit autoupdate (#12837) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson --- .pre-commit-config.yaml | 2 +- tutorials/inverse/70_eeg_mri_coords.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bdbd926ebc7..c369dc630ea 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.3 + rev: v0.6.4 hooks: - id: ruff name: ruff lint mne diff --git a/tutorials/inverse/70_eeg_mri_coords.py b/tutorials/inverse/70_eeg_mri_coords.py index 5feeca0d2bf..20d6b62e4c9 100644 --- a/tutorials/inverse/70_eeg_mri_coords.py +++ b/tutorials/inverse/70_eeg_mri_coords.py @@ -104,7 +104,12 @@ # You can also verify that these are correct (or manually convert voxels # to MRI coords) by looking at the points in Freeview or tkmedit. -dig_montage = read_custom_montage(fname_mon, head_size=None, coord_frame="mri") +dig_montage = read_custom_montage( + fname_mon, + head_size=None, + coord_frame="mri", + verbose="error", # because it contains a duplicate point +) dig_montage.plot() ############################################################################## From 5425ef42e41b6a427f3365e13ea57ecf9c0c12b0 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 13 Sep 2024 09:05:26 -0400 Subject: [PATCH 12/55] MAINT: Remove BaseEstimator (#12834) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- doc/changes/devel/12834.dependency.rst | 2 + doc/conf.py | 28 +++ examples/visualization/evoked_whitening.py | 1 + mne/cov.py | 6 +- mne/decoding/__init__.pyi | 9 +- mne/decoding/base.py | 58 ++--- mne/decoding/csp.py | 3 +- mne/decoding/ems.py | 4 +- mne/decoding/mixin.py | 89 ------- mne/decoding/receptive_field.py | 18 +- mne/decoding/search_light.py | 28 +-- mne/decoding/ssd.py | 18 +- mne/decoding/tests/test_base.py | 74 +++--- mne/decoding/tests/test_csp.py | 16 +- mne/decoding/tests/test_ems.py | 8 +- mne/decoding/tests/test_receptive_field.py | 67 +++-- mne/decoding/tests/test_search_light.py | 57 ++--- mne/decoding/tests/test_ssd.py | 10 +- mne/decoding/tests/test_time_frequency.py | 7 +- mne/decoding/tests/test_transformer.py | 16 +- mne/decoding/time_delaying_ridge.py | 3 - mne/decoding/time_frequency.py | 3 +- mne/decoding/transformer.py | 3 +- mne/fixes.py | 276 ++++----------------- mne/preprocessing/tests/test_xdawn.py | 8 +- mne/tests/test_docstring_parameters.py | 15 +- pyproject.toml | 4 + tools/install_pre_requirements.sh | 4 +- tools/vulture_allowlist.py | 4 + tutorials/forward/90_compute_covariance.py | 9 +- 30 files changed, 283 insertions(+), 565 deletions(-) create mode 100644 doc/changes/devel/12834.dependency.rst delete mode 100644 mne/decoding/mixin.py diff --git a/doc/changes/devel/12834.dependency.rst b/doc/changes/devel/12834.dependency.rst new file mode 100644 index 00000000000..ca19423df87 --- /dev/null +++ b/doc/changes/devel/12834.dependency.rst @@ -0,0 +1,2 @@ +Importing from ``mne.decoding`` now explicitly requires ``scikit-learn`` to be installed, +by `Eric Larson`_. diff --git a/doc/conf.py b/doc/conf.py index 32eedf80dba..cb7510b34a7 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -209,6 +209,8 @@ "ColorbarBase": "matplotlib.colorbar.ColorbarBase", # sklearn "LeaveOneOut": "sklearn.model_selection.LeaveOneOut", + "MetadataRequest": "sklearn.utils.metadata_routing.MetadataRequest", + "estimator": "sklearn.base.BaseEstimator", # joblib "joblib.Parallel": "joblib.Parallel", # nibabel @@ -397,6 +399,9 @@ "mapping", "to", "any", + "pandas", + "polars", + "default", # unlinkable "CoregistrationUI", "mne_qt_browser.figure.MNEQtBrowser", @@ -600,6 +605,28 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): """.format(name.split(".")[-1], name).split("\n") +def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines): + """Fix sklearn docstrings because they use autolink and we do not.""" + if ( + name.startswith("mne.decoding.") or name.startswith("mne.preprocessing.Xdawn") + ) and name.endswith( + ( + ".get_metadata_routing", + ".fit", + ".fit_transform", + ".set_output", + ".transform", + ) + ): + if ":Parameters:" in lines: + loc = lines.index(":Parameters:") + else: + loc = lines.index(":Returns:") + lines.insert(loc, "") + lines.insert(loc, ".. default-role:: autolink") + lines.insert(loc, "") + + # -- Other extension configuration ------------------------------------------- # Consider using http://magjac.com/graphviz-visual-editor for this @@ -1659,6 +1686,7 @@ def make_version(app, exception): def setup(app): """Set up the Sphinx app.""" app.connect("autodoc-process-docstring", append_attr_meth_examples) + app.connect("autodoc-process-docstring", fix_sklearn_inherited_docstrings) # High prio, will happen before SG app.connect("builder-inited", generate_credit_rst, priority=10) app.connect("builder-inited", report_scraper.set_dirs, priority=20) diff --git a/examples/visualization/evoked_whitening.py b/examples/visualization/evoked_whitening.py index 9a474d9ea36..ed05ae3ba11 100644 --- a/examples/visualization/evoked_whitening.py +++ b/examples/visualization/evoked_whitening.py @@ -16,6 +16,7 @@ ---------- .. footbibliography:: """ + # Authors: Alexandre Gramfort # Denis A. Engemann # diff --git a/mne/cov.py b/mne/cov.py index 4fbfcb8d518..8dba6f9b8a3 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -49,8 +49,8 @@ from .event import make_fixed_length_events from .evoked import EvokedArray from .fixes import ( - BaseEstimator, EmpiricalCovariance, + _EstimatorMixin, _logdet, _safe_svd, empirical_covariance, @@ -1512,7 +1512,7 @@ def _auto_low_rank_model( # Sklearn Estimators -class _RegCovariance(BaseEstimator): +class _RegCovariance(_EstimatorMixin): """Aux class.""" def __init__( @@ -1595,7 +1595,7 @@ def get_precision(self): return self.estimator_.get_precision() -class _ShrunkCovariance(BaseEstimator): +class _ShrunkCovariance(_EstimatorMixin): """Aux class.""" def __init__(self, store_precision, assume_centered, shrinkage=0.1): diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi index 4c37a6bc496..2b6c89b2140 100644 --- a/mne/decoding/__init__.pyi +++ b/mne/decoding/__init__.pyi @@ -21,10 +21,15 @@ __all__ = [ "cross_val_multiscore", "get_coef", ] -from .base import BaseEstimator, LinearModel, cross_val_multiscore, get_coef +from .base import ( + BaseEstimator, + LinearModel, + TransformerMixin, + cross_val_multiscore, + get_coef, +) from .csp import CSP, SPoC from .ems import EMS, compute_ems -from .mixin import TransformerMixin from .receptive_field import ReceptiveField from .search_light import GeneralizingEstimator, SlidingEstimator from .ssd import SSD diff --git a/mne/decoding/base.py b/mne/decoding/base.py index d8cf6104fac..a8e457137da 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -8,9 +8,18 @@ import numbers import numpy as np -from scipy.sparse import issparse +from sklearn import model_selection as models +from sklearn.base import ( # noqa: F401 + BaseEstimator, + TransformerMixin, + clone, + is_classifier, +) +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import check_scoring +from sklearn.model_selection import KFold, StratifiedKFold, check_cv +from sklearn.utils import check_array, indexable -from ..fixes import BaseEstimator, _check_fit_params, _get_check_scoring from ..parallel import parallel_func from ..utils import _pl, logger, verbose, warn @@ -64,15 +73,10 @@ class LinearModel(BaseEstimator): def __init__(self, model=None): if model is None: - from sklearn.linear_model import LogisticRegression - model = LogisticRegression(solver="liblinear") self.model = model - def _more_tags(self): - return {"no_validation": True} - def __getattr__(self, attr): """Wrap to model for some attributes.""" if attr in LinearModel._model_attr_wrap: @@ -104,22 +108,14 @@ def fit(self, X, y, **fit_params): self : instance of LinearModel Returns the modified instance. """ - # Once we require sklearn 1.1+ we should do: - # from sklearn.utils import check_array - # X = check_array(X, input_name="X") - # y = check_array(y, dtype=None, ensure_2d=False, input_name="y") - if issparse(X): - raise TypeError("X should be a dense array, got sparse instead.") - X, y = np.asarray(X), np.asarray(y) - if X.ndim != 2: - raise ValueError( - f"LinearModel only accepts 2-dimensional X, got {X.shape} instead." - ) - if y.ndim > 2: - raise ValueError( - f"LinearModel only accepts up to 2-dimensional y, got {y.shape} " - "instead." - ) + X = check_array(X, input_name="X") + if y is not None: + y = check_array(y, dtype=None, ensure_2d=False, input_name="y") + if y.ndim > 2: + raise ValueError( + f"LinearModel only accepts up to 2-dimensional y, got {y.shape} " + "instead." + ) # fit the Model self.model.fit(X, y, **fit_params) @@ -153,16 +149,12 @@ def filters_(self): def _set_cv(cv, estimator=None, X=None, y=None): """Set the default CV depending on whether clf is classifier/regressor.""" # Detect whether classification or regression - from sklearn.base import is_classifier if estimator in ["classifier", "regressor"]: est_is_classifier = estimator == "classifier" else: est_is_classifier = is_classifier(estimator) # Setup CV - from sklearn import model_selection as models - from sklearn.model_selection import KFold, StratifiedKFold, check_cv - if isinstance(cv, int | np.int64): XFold = StratifiedKFold if est_is_classifier else KFold cv = XFold(n_splits=cv) @@ -391,12 +383,6 @@ def cross_val_multiscore( Array of scores of the estimator for each run of the cross validation. """ # This code is copied from sklearn - from sklearn.base import clone, is_classifier - from sklearn.model_selection._split import check_cv - from sklearn.utils import indexable - - check_scoring = _get_check_scoring() - X, y, groups = indexable(X, y, groups) cv = check_cv(cv, y, classifier=is_classifier(estimator)) @@ -449,12 +435,16 @@ def _fit_and_score( ): """Fit estimator and compute scores for a given dataset split.""" # This code is adapted from sklearn + from sklearn.model_selection import _validation from sklearn.utils.metaestimators import _safe_split from sklearn.utils.validation import _num_samples # Adjust length of sample weights + fit_params = fit_params if fit_params is not None else {} - fit_params = _check_fit_params(X, fit_params, train) + fit_params = { + k: _validation._index_param_value(X, v, train) for k, v in fit_params.items() + } if parameters is not None: estimator.set_params(**parameters) diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 6d92b5d17bd..1261ca82055 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -6,6 +6,7 @@ import numpy as np from scipy.linalg import eigh +from sklearn.base import BaseEstimator, TransformerMixin from .._fiff.meas_info import create_info from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh @@ -19,8 +20,6 @@ pinv, warn, ) -from .base import BaseEstimator -from .mixin import TransformerMixin @fill_doc diff --git a/mne/decoding/ems.py b/mne/decoding/ems.py index 511a4ce1681..f6811de460d 100644 --- a/mne/decoding/ems.py +++ b/mne/decoding/ems.py @@ -5,15 +5,15 @@ from collections import Counter import numpy as np +from sklearn.base import BaseEstimator, TransformerMixin from .._fiff.pick import _picks_to_idx, pick_info, pick_types from ..parallel import parallel_func from ..utils import logger, verbose from .base import _set_cv -from .mixin import EstimatorMixin, TransformerMixin -class EMS(TransformerMixin, EstimatorMixin): +class EMS(BaseEstimator, TransformerMixin): """Transformer to compute event-matched spatial filters. This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire diff --git a/mne/decoding/mixin.py b/mne/decoding/mixin.py deleted file mode 100644 index c660cb7eca7..00000000000 --- a/mne/decoding/mixin.py +++ /dev/null @@ -1,89 +0,0 @@ -# Authors: The MNE-Python contributors. -# License: BSD-3-Clause -# Copyright the MNE-Python contributors. - - -class TransformerMixin: - """Mixin class for all transformers in scikit-learn.""" - - def fit_transform(self, X, y=None, **fit_params): - """Fit to data, then transform it. - - Fits transformer to ``X`` and ``y`` with optional parameters - ``fit_params``, and returns a transformed version of ``X``. - - Parameters - ---------- - X : array, shape (n_samples, n_features) - Training set. - y : array, shape (n_samples,) - Target values or class labels. - **fit_params : dict - Additional fitting parameters passed to the ``fit`` method.. - - Returns - ------- - X_new : array, shape (n_samples, n_features_new) - Transformed array. - """ - # non-optimized default implementation; override when a better - # method is possible for a given clustering algorithm - if y is None: - # fit method of arity 1 (unsupervised transformation) - return self.fit(X, **fit_params).transform(X) - else: - # fit method of arity 2 (supervised transformation) - return self.fit(X, y, **fit_params).transform(X) - - -class EstimatorMixin: - """Mixin class for estimators.""" - - def get_params(self, deep=True): - """Get the estimator params. - - Parameters - ---------- - deep : bool - Deep. - """ - return - - def set_params(self, **params): - """Set parameters (mimics sklearn API). - - Parameters - ---------- - **params : dict - Extra parameters. - - Returns - ------- - inst : object - The instance. - """ - if not params: - return self - valid_params = self.get_params(deep=True) - for key, value in params.items(): - split = key.split("__", 1) - if len(split) > 1: - # nested objects case - name, sub_name = split - if name not in valid_params: - raise ValueError( - f"Invalid parameter {name} for estimator {self}. Check the list" - " of available parameters with `estimator.get_params().keys()`." - ) - sub_object = valid_params[name] - sub_object.set_params(**{sub_name: value}) - else: - # simple objects case - if key not in valid_params: - raise ValueError( - f"Invalid parameter {key} for estimator " - f"{self.__class__.__name__}. Check the list of available " - "parameters with `estimator.get_params().keys()`." - ) - setattr(self, key, value) - return self diff --git a/mne/decoding/receptive_field.py b/mne/decoding/receptive_field.py index c03596bb811..5fc985a81b3 100644 --- a/mne/decoding/receptive_field.py +++ b/mne/decoding/receptive_field.py @@ -6,8 +6,10 @@ import numpy as np from scipy.stats import pearsonr +from sklearn.base import clone, is_regressor +from sklearn.metrics import r2_score -from ..utils import _validate_type, fill_doc, pinv, verbose +from ..utils import _validate_type, fill_doc, pinv from .base import BaseEstimator, _check_estimator, get_coef from .time_delaying_ridge import TimeDelayingRidge @@ -65,7 +67,6 @@ class ReceptiveField(BaseEstimator): duration. Only used if ``estimator`` is float or None. .. versionadded:: 0.18 - %(verbose)s Attributes ---------- @@ -101,7 +102,6 @@ class ReceptiveField(BaseEstimator): .. footbibliography:: """ # noqa E501 - @verbose def __init__( self, tmin, @@ -114,12 +114,11 @@ def __init__( patterns=False, n_jobs=None, edge_correction=True, - verbose=None, ): - self.feature_names = feature_names - self.sfreq = float(sfreq) self.tmin = tmin self.tmax = tmax + self.sfreq = float(sfreq) + self.feature_names = feature_names self.estimator = 0.0 if estimator is None else estimator self.fit_intercept = fit_intercept self.scoring = scoring @@ -127,9 +126,6 @@ def __init__( self.n_jobs = n_jobs self.edge_correction = edge_correction - def _more_tags(self): - return {"no_validation": True} - def __repr__(self): # noqa: D105 s = f"tmin, tmax : ({self.tmin:.3f}, {self.tmax:.3f}), " estimator = self.estimator @@ -186,8 +182,6 @@ def fit(self, X, y): raise ValueError( f"scoring must be one of {sorted(_SCORERS.keys())}, got {self.scoring} " ) - from sklearn.base import clone, is_regressor - X, y, _, self._y_dim = self._check_dimensions(X, y) if self.tmin > self.tmax: @@ -514,8 +508,6 @@ def _corr_score(y_true, y, multioutput=None): def _r2_score(y_true, y, multioutput=None): - from sklearn.metrics import r2_score - return r2_score(y_true, y, multioutput=multioutput) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 01e55a263d3..64f38a60634 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -5,13 +5,14 @@ import logging import numpy as np -from scipy.sparse import issparse +from sklearn.base import BaseEstimator, TransformerMixin, clone +from sklearn.metrics import check_scoring +from sklearn.preprocessing import LabelEncoder +from sklearn.utils import check_array -from ..fixes import _get_check_scoring from ..parallel import parallel_func from ..utils import ProgressBar, _parse_verbose, array_split_idx, fill_doc, verbose -from .base import BaseEstimator, _check_estimator -from .mixin import TransformerMixin +from .base import _check_estimator @fill_doc @@ -56,9 +57,6 @@ def __init__( self.allow_2d = allow_2d self.verbose = verbose - def _more_tags(self): - return {"no_validation": True, "requires_fit": False} - @property def _estimator_type(self): return getattr(self.base_estimator, "_estimator_type", None) @@ -256,14 +254,9 @@ def decision_function(self, X): def _check_Xy(self, X, y=None): """Aux. function to check input data.""" # Once we require sklearn 1.1+ we should do something like: - # from sklearn.utils import check_array - # X = check_array(X, ensure_2d=False, input_name="X") - # y = check_array(y, dtype=None, ensure_2d=False, input_name="y") - if issparse(X): - raise TypeError("X should be a dense array, got sparse instead.") - X = np.asarray(X) + X = check_array(X, ensure_2d=False, allow_nd=True, input_name="X") if y is not None: - y = np.asarray(y) + y = check_array(y, dtype=None, ensure_2d=False, input_name="y") if len(X) != len(y) or len(y) < 1: raise ValueError("X and y must have the same length.") if X.ndim < 3: @@ -300,8 +293,6 @@ def score(self, X, y): score : array, shape (n_samples, n_estimators) Score for each estimator/task. """ # noqa: E501 - check_scoring = _get_check_scoring() - X = self._check_Xy(X, y) if X.shape[-1] != len(self.estimators_): raise ValueError("The number of estimators does not match X.shape[-1]") @@ -357,8 +348,6 @@ def _sl_fit(estimator, X, y, pb, **fit_params): estimators_ : list of estimators The fitted estimators. """ - from sklearn.base import clone - estimators_ = list() for ii in range(X.shape[-1]): est = clone(estimator) @@ -600,7 +589,6 @@ def score(self, X, y): score : array, shape (n_samples, n_estimators, n_slices) Score for each estimator / data slice couple. """ # noqa: E501 - check_scoring = _get_check_scoring() X = self._check_Xy(X, y) # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. @@ -719,8 +707,6 @@ def _gl_score(estimators, scoring, X, y, pb): def _fix_auc(scoring, y): - from sklearn.preprocessing import LabelEncoder - # This fixes sklearn's inability to compute roc_auc when y not in [0, 1] # scikit-learn/scikit-learn#6874 if scoring is not None: diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 23e3136ce36..4043aa99835 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -4,12 +4,12 @@ import numpy as np from scipy.linalg import eigh +from sklearn.base import BaseEstimator, TransformerMixin from .._fiff.pick import _picks_to_idx from ..cov import Covariance, _regularized_covariance from ..defaults import _handle_default from ..filter import filter_data -from ..fixes import BaseEstimator from ..rank import compute_rank from ..time_frequency import psd_array_welch from ..utils import ( @@ -20,7 +20,6 @@ fill_doc, logger, ) -from .mixin import TransformerMixin @fill_doc @@ -125,14 +124,8 @@ def __init__( "The signal band-pass must be within the noise " "band-pass!" ) - self.picks_ = _picks_to_idx(info, picks, none="data", exclude="bads") + self.picks = picks del picks - ch_types = info.get_channel_types(picks=self.picks_, unique=True) - if len(ch_types) > 1: - raise ValueError( - "At this point SSD only supports fitting " - "single channel types. Your info has %i types" % (len(ch_types)) - ) self.info = info self.freqs_signal = (filt_params_signal["l_freq"], filt_params_signal["h_freq"]) self.freqs_noise = (filt_params_noise["l_freq"], filt_params_noise["h_freq"]) @@ -183,6 +176,13 @@ def fit(self, X, y=None): self : instance of SSD Returns the modified instance. """ + ch_types = self.info.get_channel_types(picks=self.picks, unique=True) + if len(ch_types) > 1: + raise ValueError( + "At this point SSD only supports fitting " + "single channel types. Your info has %i types" % (len(ch_types)) + ) + self.picks_ = _picks_to_idx(self.info, self.picks, none="data", exclude="bads") self._check_X(X) X_aux = X[..., self.picks_, :] diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 10d9950bbf7..25fbba3fafd 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -14,6 +14,30 @@ assert_equal, ) +pytest.importorskip("sklearn") + +from sklearn import svm +from sklearn.base import ( + BaseEstimator as sklearn_BaseEstimator, +) +from sklearn.base import ( + TransformerMixin as sklearn_TransformerMixin, +) +from sklearn.base import ( + is_classifier, + is_regressor, +) +from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge +from sklearn.model_selection import ( + GridSearchCV, + KFold, + StratifiedKFold, + cross_val_score, +) +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.utils.estimator_checks import parametrize_with_checks + from mne import EpochsArray, create_info from mne.decoding import GeneralizingEstimator, Scaler, TransformerMixin, Vectorizer from mne.decoding.base import ( @@ -25,8 +49,6 @@ ) from mne.decoding.search_light import SlidingEstimator -pytest.importorskip("sklearn") - def _make_data(n_samples=1000, n_features=5, n_targets=3): """Generate some testing data. @@ -70,18 +92,6 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3): @pytest.mark.filterwarnings("ignore:invalid value encountered in cast.*:RuntimeWarning") def test_get_coef(): """Test getting linear coefficients (filters/patterns) from estimators.""" - from sklearn import svm - from sklearn.base import ( - BaseEstimator, - TransformerMixin, - is_classifier, - is_regressor, - ) - from sklearn.linear_model import Ridge - from sklearn.model_selection import GridSearchCV - from sklearn.pipeline import make_pipeline - from sklearn.preprocessing import StandardScaler - lm_classification = LinearModel() assert is_classifier(lm_classification) @@ -100,6 +110,8 @@ def test_get_coef(): assert is_regressor(lm_gs_regression) # Define a classifier, an invertible transformer and an non-invertible one. + assert BaseEstimator is sklearn_BaseEstimator + assert TransformerMixin is sklearn_TransformerMixin class Clf(BaseEstimator): def fit(self, X, y): @@ -223,9 +235,6 @@ def transform(self, X): ) def test_get_coef_inverse_transform(inverse, Scale, kwargs): """Test get_coef with and without inverse_transform.""" - from sklearn.linear_model import Ridge - from sklearn.pipeline import make_pipeline - lm_regression = LinearModel(Ridge()) X, y, A = _make_data(n_samples=1000, n_features=3, n_targets=1) # Check with search_light and combination of preprocessing ending with sl: @@ -254,9 +263,6 @@ def test_get_coef_inverse_transform(inverse, Scale, kwargs): def test_get_coef_multiclass(n_features, n_targets): """Test get_coef on multiclass problems.""" # Check patterns with more than 1 regressor - from sklearn.linear_model import LinearRegression, Ridge - from sklearn.pipeline import make_pipeline - X, Y, A = _make_data(n_samples=30000, n_features=n_features, n_targets=n_targets) lm = LinearModel(LinearRegression()).fit(X, Y) assert_array_equal(lm.filters_.shape, lm.patterns_.shape) @@ -308,10 +314,6 @@ def test_get_coef_multiclass(n_features, n_targets): @pytest.mark.filterwarnings("ignore:'multi_class' was deprecated in.*:FutureWarning") def test_get_coef_multiclass_full(n_classes, n_channels, n_times): """Test a full example with pattern extraction.""" - from sklearn.linear_model import LogisticRegression - from sklearn.model_selection import StratifiedKFold - from sklearn.pipeline import make_pipeline - data = np.zeros((10 * n_classes, n_channels, n_times)) # Make only the first channel informative for ii in range(n_classes): @@ -347,8 +349,6 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times): def test_linearmodel(): """Test LinearModel class for computing filters and patterns.""" # check categorical target fit in standard linear model - from sklearn.linear_model import LinearRegression - rng = np.random.RandomState(0) clf = LinearModel() n, n_features = 20, 3 @@ -362,9 +362,6 @@ def test_linearmodel(): clf.fit(wrong_X, y) # check categorical target fit in standard linear model with GridSearchCV - from sklearn import svm - from sklearn.model_selection import GridSearchCV - parameters = {"kernel": ["linear"], "C": [1, 10]} clf = LinearModel( GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None) @@ -403,9 +400,6 @@ def test_linearmodel(): def test_cross_val_multiscore(): """Test cross_val_multiscore for computing scores on decoding over time.""" - from sklearn.linear_model import LinearRegression, LogisticRegression - from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score - logreg = LogisticRegression(solver="liblinear", random_state=0) # compare to cross-val-score @@ -462,19 +456,15 @@ def test_cross_val_multiscore(): assert_array_equal(manual, auto) -def test_sklearn_compliance(): +@parametrize_with_checks([LinearModel(LogisticRegression())]) +def test_sklearn_compliance(estimator, check): """Test LinearModel compliance with sklearn.""" - pytest.importorskip("sklearn") - from sklearn.linear_model import LogisticRegression - from sklearn.utils.estimator_checks import check_estimator - - lm = LinearModel(LogisticRegression()) ignores = ( + "check_n_features_in", # maybe we should add this someday? "check_estimator_sparse_data", # we densify "check_estimators_overwrite_params", # self.model changes! "check_parameters_default_constructible", ) - for est, check in check_estimator(lm, generate_only=True): - if any(ignore in str(check) for ignore in ignores): - continue - check(est) + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index 79528dc5615..7a1a83feeaf 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -13,6 +13,13 @@ assert_equal, ) +pytest.importorskip("sklearn") + +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import StratifiedKFold, cross_val_score +from sklearn.pipeline import Pipeline, make_pipeline +from sklearn.svm import SVC + from mne import Epochs, compute_proj_raw, io, pick_types, read_events from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef from mne.decoding.csp import _ajd_pham @@ -255,11 +262,6 @@ def test_csp(): @pytest.mark.parametrize("reg", [None, 0.001, "oas"]) def test_regularized_csp(ch_type, rank, reg): """Test Common Spatial Patterns algorithm using regularized covariance.""" - pytest.importorskip("sklearn") - from sklearn.linear_model import LogisticRegression - from sklearn.model_selection import StratifiedKFold, cross_val_score - from sklearn.pipeline import make_pipeline - raw = io.read_raw_fif(raw_fname).pick(ch_type, exclude="bads").load_data() n_orig = len(raw.ch_names) ch_decim = 2 @@ -373,10 +375,6 @@ def test_regularized_csp(ch_type, rank, reg): def test_csp_pipeline(): """Test if CSP works in a pipeline.""" - pytest.importorskip("sklearn") - from sklearn.pipeline import Pipeline - from sklearn.svm import SVC - csp = CSP(reg=1, norm_trace=False) svc = SVC() pipe = Pipeline([("CSP", csp), ("SVC", svc)]) diff --git a/mne/decoding/tests/test_ems.py b/mne/decoding/tests/test_ems.py index 6a5effc07b7..10774c0681a 100644 --- a/mne/decoding/tests/test_ems.py +++ b/mne/decoding/tests/test_ems.py @@ -8,6 +8,10 @@ import pytest from numpy.testing import assert_array_almost_equal, assert_equal +pytest.importorskip("sklearn") + +from sklearn.model_selection import StratifiedKFold + from mne import Epochs, io, pick_types, read_events from mne.decoding import EMS, compute_ems @@ -17,13 +21,9 @@ tmin, tmax = -0.2, 0.5 event_id = dict(aud_l=1, vis_l=3) -pytest.importorskip("sklearn") - def test_ems(): """Test event-matched spatial filters.""" - from sklearn.model_selection import StratifiedKFold - raw = io.read_raw_fif(raw_fname, preload=False) # create unequal number of events diff --git a/mne/decoding/tests/test_receptive_field.py b/mne/decoding/tests/test_receptive_field.py index d46289819b8..5d7e1ff0661 100644 --- a/mne/decoding/tests/test_receptive_field.py +++ b/mne/decoding/tests/test_receptive_field.py @@ -10,6 +10,11 @@ from numpy.fft import irfft, rfft from numpy.testing import assert_allclose, assert_array_equal, assert_equal +pytest.importorskip("sklearn") + +from sklearn.linear_model import Ridge +from sklearn.utils.estimator_checks import parametrize_with_checks + from mne.decoding import ReceptiveField, TimeDelayingRidge from mne.decoding.receptive_field import ( _SCORERS, @@ -79,9 +84,6 @@ def test_compute_reg_neighbors(): def test_rank_deficiency(): """Test signals that are rank deficient.""" # See GH#4253 - pytest.importorskip("sklearn") - from sklearn.linear_model import Ridge - N = 256 fs = 1.0 tmin, tmax = -50, 100 @@ -174,9 +176,6 @@ def test_time_delay(): @pytest.mark.parametrize("n_jobs", n_jobs_test) def test_receptive_field_basic(n_jobs): """Test model prep and fitting.""" - pytest.importorskip("sklearn") - from sklearn.linear_model import Ridge - # Make sure estimator pulling works mod = Ridge() rng = np.random.RandomState(1337) @@ -372,9 +371,6 @@ def test_time_delaying_fast_calc(n_jobs): @pytest.mark.parametrize("n_jobs", n_jobs_test) def test_receptive_field_1d(n_jobs): """Test that the fast solving works like Ridge.""" - pytest.importorskip("sklearn") - from sklearn.linear_model import Ridge - rng = np.random.RandomState(0) x = rng.randn(500, 1) for delay in range(-2, 3): @@ -433,9 +429,6 @@ def test_receptive_field_1d(n_jobs): @pytest.mark.parametrize("n_jobs", n_jobs_test) def test_receptive_field_nd(n_jobs): """Test multidimensional support.""" - pytest.importorskip("sklearn") - from sklearn.linear_model import Ridge - # multidimensional rng = np.random.RandomState(3) x = rng.randn(1000, 3) @@ -552,9 +545,6 @@ def _make_data(n_feats, n_targets, n_samples, tmin, tmax): def test_inverse_coef(): """Test inverse coefficients computation.""" - pytest.importorskip("sklearn") - from sklearn.linear_model import Ridge - tmin, tmax = 0.0, 10.0 n_feats, n_targets, n_samples = 3, 2, 1000 n_delays = int((tmax - tmin) + 1) @@ -583,9 +573,6 @@ def test_inverse_coef(): def test_linalg_warning(): """Test that warnings are issued when no regularization is applied.""" - pytest.importorskip("sklearn") - from sklearn.linear_model import Ridge - n_feats, n_targets, n_samples = 5, 60, 50 X, y = _make_data(n_feats, n_targets, n_samples, tmin, tmax) for estimator in (0.0, Ridge(alpha=0.0)): @@ -596,12 +583,9 @@ def test_linalg_warning(): rf.fit(y, X) -def test_tdr_sklearn_compliance(): +@parametrize_with_checks([TimeDelayingRidge(0, 10, 1.0, 0.1, "laplacian", n_jobs=1)]) +def test_tdr_sklearn_compliance(estimator, check): """Test sklearn estimator compliance.""" - pytest.importorskip("sklearn") - from sklearn.utils.estimator_checks import check_estimator - - tdr = TimeDelayingRidge(0, 10, 1.0, 0.1, "laplacian", n_jobs=1) # We don't actually comply with a bunch of the regressor specs :( ignores = ( "check_supervised_y_no_nan", @@ -609,27 +593,36 @@ def test_tdr_sklearn_compliance(): "check_parameters_default_constructible", "check_estimators_unfitted", "_invariance", + "check_complex_data", + "check_estimators_empty_data_messages", + "check_estimators_nan_inf", + "check_supervised_y_2d", + "check_n_features_in", "check_fit2d_1sample", + "check_fit1d", + "check_fit2d_predict1d", + "check_requires_y_none", ) - for est, check in check_estimator(tdr, generate_only=True): - if any(ignore in str(check) for ignore in ignores): - continue - check(est) + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) -def test_rf_sklearn_compliance(): +@pytest.mark.filterwarnings("ignore:.*invalid value encountered in subtract.*:") +@parametrize_with_checks([ReceptiveField(-1, 2, 1.0, estimator=Ridge(), patterns=True)]) +def test_rf_sklearn_compliance(estimator, check): """Test sklearn RF compliance.""" - pytest.importorskip("sklearn") - from sklearn.linear_model import Ridge - from sklearn.utils.estimator_checks import check_estimator - - rf = ReceptiveField(-1, 2, 1.0, estimator=Ridge(), patterns=True) ignores = ( "check_parameters_default_constructible", "_invariance", "check_fit2d_1sample", + # Should probably fix these? + "check_complex_data", + "check_dtype_object", + "check_estimators_empty_data_messages", + "check_n_features_in", + "check_fit2d_predict1d", ) - for est, check in check_estimator(rf, generate_only=True): - if any(ignore in str(check) for ignore in ignores): - continue - check(est) + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 329b6b3d30f..fe605abca06 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -9,11 +9,22 @@ import pytest from numpy.testing import assert_array_equal, assert_equal +sklearn = pytest.importorskip("sklearn") + +from sklearn.base import BaseEstimator, clone, is_classifier +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.ensemble import BaggingClassifier +from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge +from sklearn.metrics import make_scorer, roc_auc_score +from sklearn.model_selection import cross_val_predict +from sklearn.multiclass import OneVsRestClassifier +from sklearn.pipeline import make_pipeline +from sklearn.svm import SVC +from sklearn.utils.estimator_checks import parametrize_with_checks + from mne.decoding.search_light import GeneralizingEstimator, SlidingEstimator from mne.decoding.transformer import Vectorizer -from mne.utils import _record_warnings, check_version, use_log_level - -sklearn = pytest.importorskip("sklearn") +from mne.utils import check_version, use_log_level NEW_MULTICLASS_SAMPLE_WEIGHT = check_version("sklearn", "1.4") @@ -35,14 +46,6 @@ def test_search_light(): # https://github.com/scikit-learn/scikit-learn/issues/27711 if platform.system() == "Windows" and check_version("numpy", "2.0.0.dev0"): pytest.skip("sklearn int_t / long long mismatch") - from sklearn.linear_model import LogisticRegression, Ridge - from sklearn.metrics import make_scorer, roc_auc_score - from sklearn.multiclass import OneVsRestClassifier - from sklearn.pipeline import make_pipeline - - with _record_warnings(): # NumPy module import - from sklearn.ensemble import BaggingClassifier - from sklearn.base import is_classifier logreg = OneVsRestClassifier(LogisticRegression(solver="liblinear", random_state=0)) @@ -197,11 +200,6 @@ def metadata_routing(): def test_generalization_light(metadata_routing): """Test GeneralizingEstimator.""" - from sklearn.linear_model import LogisticRegression - from sklearn.metrics import roc_auc_score - from sklearn.multiclass import OneVsRestClassifier - from sklearn.pipeline import make_pipeline - if NEW_MULTICLASS_SAMPLE_WEIGHT: clf = LogisticRegression(random_state=0) clf.set_fit_request(sample_weight=True) @@ -296,8 +294,6 @@ def test_generalization_light(metadata_routing): ) def test_verbose_arg(capsys, n_jobs, verbose): """Test controlling output with the ``verbose`` argument.""" - from sklearn.svm import SVC - X, y = make_data() clf = SVC() @@ -318,11 +314,6 @@ def test_verbose_arg(capsys, n_jobs, verbose): def test_cross_val_predict(): """Test cross_val_predict with predict_proba.""" - from sklearn.base import BaseEstimator, clone - from sklearn.discriminant_analysis import LinearDiscriminantAnalysis - from sklearn.linear_model import LinearRegression - from sklearn.model_selection import cross_val_predict - rng = np.random.RandomState(42) X = rng.randn(10, 1, 3) y = rng.randint(0, 2, 10) @@ -352,13 +343,9 @@ def predict_proba(self, X): @pytest.mark.slowtest -def test_sklearn_compliance(): +@parametrize_with_checks([SlidingEstimator(LogisticRegression(), allow_2d=True)]) +def test_sklearn_compliance(estimator, check): """Test LinearModel compliance with sklearn.""" - from sklearn.linear_model import LogisticRegression - from sklearn.utils.estimator_checks import check_estimator - - est = SlidingEstimator(LogisticRegression(), allow_2d=True) - ignores = ( "check_estimator_sparse_data", # we densify "check_classifiers_one_label_sample_weights", # don't handle singleton @@ -366,8 +353,12 @@ def test_sklearn_compliance(): "check_classifiers_train", "check_decision_proba_consistency", "check_parameters_default_constructible", + # Should probably fix these? + "check_estimators_unfitted", + "check_transformer_data_not_an_array", + "check_n_features_in", + "check_fit2d_predict1d", ) - for est, check in check_estimator(est, generate_only=True): - if any(ignore in str(check) for ignore in ignores): - continue - check(est) + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index e306dffa2db..198feeb6532 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -8,6 +8,10 @@ import pytest from numpy.testing import assert_array_almost_equal, assert_array_equal +pytest.importorskip("sklearn") + +from sklearn.pipeline import Pipeline + from mne import create_info, io from mne.decoding import CSP from mne.decoding.ssd import SSD @@ -150,8 +154,9 @@ def test_ssd(): ch_types = np.reshape([["mag"] * 10, ["eeg"] * 10], n_channels) info_2 = create_info(ch_names=n_channels, sfreq=sf, ch_types=ch_types) + ssd = SSD(info_2, filt_params_signal, filt_params_noise) with pytest.raises(ValueError, match="At this point SSD"): - ssd = SSD(info_2, filt_params_signal, filt_params_noise) + ssd.fit(X) # Number of channels info_3 = create_info(ch_names=n_channels + 1, sfreq=sf, ch_types="eeg") @@ -298,9 +303,6 @@ def test_ssd_epoched_data(): def test_ssd_pipeline(): """Test if SSD works in a pipeline.""" - pytest.importorskip("sklearn") - from sklearn.pipeline import Pipeline - sf = 250 X, A, S = simulate_data(n_trials=100, n_channels=20, n_samples=500) X_e = np.reshape(X, (100, 20, 500)) diff --git a/mne/decoding/tests/test_time_frequency.py b/mne/decoding/tests/test_time_frequency.py index 52c4e9f1bc9..37e7d7d8dc2 100644 --- a/mne/decoding/tests/test_time_frequency.py +++ b/mne/decoding/tests/test_time_frequency.py @@ -7,14 +7,15 @@ import pytest from numpy.testing import assert_array_equal +pytest.importorskip("sklearn") + +from sklearn.base import clone + from mne.decoding.time_frequency import TimeFrequency def test_timefrequency(): """Test TimeFrequency.""" - pytest.importorskip("sklearn") - from sklearn.base import clone - # Init n_freqs = 3 freqs = [20, 21, 22] diff --git a/mne/decoding/tests/test_transformer.py b/mne/decoding/tests/test_transformer.py index 098df9fa0f0..8dcc3ad74c7 100644 --- a/mne/decoding/tests/test_transformer.py +++ b/mne/decoding/tests/test_transformer.py @@ -13,6 +13,11 @@ assert_equal, ) +pytest.importorskip("sklearn") + +from sklearn.decomposition import PCA +from sklearn.kernel_ridge import KernelRidge + from mne import Epochs, io, pick_types, read_events from mne.decoding import ( FilterEstimator, @@ -23,7 +28,7 @@ Vectorizer, ) from mne.defaults import DEFAULTS -from mne.utils import check_version, use_log_level +from mne.utils import use_log_level tmin, tmax = -0.2, 0.5 event_id = dict(aud_l=1, vis_l=3) @@ -58,11 +63,6 @@ def test_scaler(info, method): y = epochs.events[:, -1] epochs_data_t = epochs_data.transpose([1, 0, 2]) - if method in ("mean", "median"): - if not check_version("sklearn"): - with pytest.raises((ImportError, RuntimeError), match=" module "): - Scaler(info, method) - return if info: info = epochs.info @@ -217,10 +217,6 @@ def test_vectorizer(): def test_unsupervised_spatial_filter(): """Test unsupervised spatial filter.""" - pytest.importorskip("sklearn") - from sklearn.decomposition import PCA - from sklearn.kernel_ridge import KernelRidge - raw = io.read_raw_fif(raw_fname) events = read_events(event_name) picks = pick_types( diff --git a/mne/decoding/time_delaying_ridge.py b/mne/decoding/time_delaying_ridge.py index 5d85556c954..e824a15be75 100644 --- a/mne/decoding/time_delaying_ridge.py +++ b/mne/decoding/time_delaying_ridge.py @@ -298,9 +298,6 @@ def __init__( self.edge_correction = edge_correction self.n_jobs = n_jobs - def _more_tags(self): - return {"no_validation": True} - @property def _smin(self): return int(round(self.tmin * self.sfreq)) diff --git a/mne/decoding/time_frequency.py b/mne/decoding/time_frequency.py index 20bcc40baca..de6ec52155b 100644 --- a/mne/decoding/time_frequency.py +++ b/mne/decoding/time_frequency.py @@ -3,11 +3,10 @@ # Copyright the MNE-Python contributors. import numpy as np +from sklearn.base import BaseEstimator, TransformerMixin from ..time_frequency.tfr import _compute_tfr from ..utils import _check_option, fill_doc, verbose -from .base import BaseEstimator -from .mixin import TransformerMixin @fill_doc diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index d3cdbc172ea..e293d108ba8 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -3,6 +3,7 @@ # Copyright the MNE-Python contributors. import numpy as np +from sklearn.base import BaseEstimator, TransformerMixin from .._fiff.pick import ( _pick_data_channels, @@ -15,8 +16,6 @@ from ..filter import filter_data from ..time_frequency import psd_array_multitaper from ..utils import _check_option, _validate_type, fill_doc, verbose -from .base import BaseEstimator -from .mixin import TransformerMixin class _ConstantScaler: diff --git a/mne/fixes.py b/mne/fixes.py index 18f4536d72b..e9d62fb42e6 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -19,9 +19,7 @@ import operator as operator_module import os import warnings -from io import StringIO from math import log -from pprint import pprint import numpy as np @@ -134,231 +132,6 @@ def _get_img_fdata(img): return data.astype(dtype) -############################################################################## -# adapted from scikit-learn - - -_DEFAULT_TAGS = { - "array_api_support": False, - "non_deterministic": False, - "requires_positive_X": False, - "requires_positive_y": False, - "X_types": ["2darray"], - "poor_score": False, - "no_validation": False, - "multioutput": False, - "allow_nan": False, - "stateless": False, - "multilabel": False, - "_skip_test": False, - "_xfail_checks": False, - "multioutput_only": False, - "binary_only": False, - "requires_fit": True, - "preserves_dtype": [np.float64], - "requires_y": False, - "pairwise": False, -} - - -class BaseEstimator: - """Base class for all estimators in scikit-learn. - - Notes - ----- - All estimators should specify all the parameters that can be set - at the class level in their ``__init__`` as explicit keyword - arguments (no ``*args`` or ``**kwargs``). - """ - - @classmethod - def _get_param_names(cls): - """Get parameter names for the estimator.""" - # fetch the constructor or the original constructor before - # deprecation wrapping if any - init = getattr(cls.__init__, "deprecated_original", cls.__init__) - if init is object.__init__: - # No explicit constructor to introspect - return [] - - # introspect the constructor arguments to find the model parameters - # to represent - init_signature = inspect.signature(init) - # Consider the constructor parameters excluding 'self' - parameters = [ - p - for p in init_signature.parameters.values() - if p.name != "self" and p.kind != p.VAR_KEYWORD - ] - for p in parameters: - if p.kind == p.VAR_POSITIONAL: - raise RuntimeError( - "scikit-learn estimators should always " - "specify their parameters in the signature" - " of their __init__ (no varargs)." - f" {cls} with constructor {init_signature} doesn't " - " follow this convention." - ) - # Extract and sort argument names excluding 'self' - return sorted([p.name for p in parameters]) - - def get_params(self, deep=True): - """Get parameters for this estimator. - - Parameters - ---------- - deep : bool, optional - If True, will return the parameters for this estimator and - contained subobjects that are estimators. - - Returns - ------- - params : dict - Parameter names mapped to their values. - """ - out = dict() - for key in self._get_param_names(): - # We need deprecation warnings to always be on in order to - # catch deprecated param values. - # This is set in utils/__init__.py but it gets overwritten - # when running under python3 somehow. - warnings.simplefilter("always", DeprecationWarning) - try: - with warnings.catch_warnings(record=True) as w: - value = getattr(self, key, None) - if len(w) and w[0].category is DeprecationWarning: - # if the parameter is deprecated, don't show it - continue - finally: - warnings.filters.pop(0) - - # XXX: should we rather test if instance of estimator? - if deep and hasattr(value, "get_params"): - deep_items = value.get_params().items() - out.update((key + "__" + k, val) for k, val in deep_items) - out[key] = value - return out - - def set_params(self, **params): - """Set the parameters of this estimator. - - The method works on simple estimators as well as on nested objects - (such as pipelines). The latter have parameters of the form - ``__`` so that it's possible to update each - component of a nested object. - - Parameters - ---------- - **params : dict - Parameters. - - Returns - ------- - inst : instance - The object. - """ - if not params: - # Simple optimisation to gain speed (inspect is slow) - return self - valid_params = self.get_params(deep=True) - for key, value in params.items(): - split = key.split("__", 1) - if len(split) > 1: - # nested objects case - name, sub_name = split - if name not in valid_params: - raise ValueError( - f"Invalid parameter {name} for estimator {self}. " - "Check the list of available parameters " - "with `estimator.get_params().keys()`." - ) - sub_object = valid_params[name] - sub_object.set_params(**{sub_name: value}) - else: - # simple objects case - if key not in valid_params: - raise ValueError( - f"Invalid parameter {key} for estimator " - f"{self.__class__.__name__}. " - "Check the list of available parameters " - "with `estimator.get_params().keys()`." - ) - setattr(self, key, value) - return self - - def __repr__(self): # noqa: D105 - params = StringIO() - pprint(self.get_params(deep=False), params) - params.seek(0) - class_name = self.__class__.__name__ - return f"{class_name}({params.read().strip()})" - - # __getstate__ and __setstate__ are omitted because they only contain - # conditionals that are not satisfied by our objects (e.g., - # ``if type(self).__module__.startswith('sklearn.')``. - - def _more_tags(self): - return _DEFAULT_TAGS - - def _get_tags(self): - collected_tags = {} - for base_class in reversed(inspect.getmro(self.__class__)): - if hasattr(base_class, "_more_tags"): - # need the if because mixins might not have _more_tags - # but might do redundant work in estimators - # (i.e. calling more tags on BaseEstimator multiple times) - more_tags = base_class._more_tags(self) - collected_tags.update(more_tags) - return collected_tags - - -# newer sklearn deprecates importing from sklearn.metrics.scoring, -# but older sklearn does not expose check_scoring in sklearn.metrics. -def _get_check_scoring(): - try: - from sklearn.metrics import check_scoring # noqa - except ImportError: - from sklearn.metrics.scorer import check_scoring # noqa - return check_scoring - - -def _check_fit_params(X, fit_params, indices=None): - """Check and validate the parameters passed during `fit`. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - Data array. - - fit_params : dict - Dictionary containing the parameters passed at fit. - - indices : array-like of shape (n_samples,), default=None - Indices to be selected if the parameter has the same size as - `X`. - - Returns - ------- - fit_params_validated : dict - Validated parameters. We ensure that the values support - indexing. - """ - try: - from sklearn.utils.validation import ( - _check_fit_params as _sklearn_check_fit_params, - ) - - return _sklearn_check_fit_params(X, fit_params, indices) - except ImportError: - from sklearn.model_selection import _validation - - fit_params_validated = { - k: _validation._index_param_value(X, v, indices) - for k, v in fit_params.items() - } - return fit_params_validated - - ############################################################################### # Copied from sklearn to simplify code paths @@ -401,7 +174,54 @@ def empirical_covariance(X, assume_centered=False): return covariance -class EmpiricalCovariance(BaseEstimator): +class _EstimatorMixin: + def _param_names(self): + return inspect.getfullargspec(self.__init__).args[1:] + + def get_params(self, deep=True): + """Get parameters for this estimator. + + Parameters + ---------- + deep : bool, default=True + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + out = dict() + for key in self._param_names(): + out[key] = getattr(self, key) + return out + + def set_params(self, **params): + """Set the parameters of this estimator. + + The method works on simple estimators as well as on nested objects + (such as pipelines). The latter have parameters of the form + ``__`` so that it's possible to update each + component of a nested object. + + Parameters + ---------- + **params : dict + Estimator parameters. + + Returns + ------- + self : object + Estimator instance. + """ + param_names = self._param_names() + for key in params: + if key in param_names: + setattr(self, key, params[key]) + + +class EmpiricalCovariance(_EstimatorMixin): """Maximum likelihood covariance estimator. Read more in the :ref:`User Guide `. diff --git a/mne/preprocessing/tests/test_xdawn.py b/mne/preprocessing/tests/test_xdawn.py index 565a12e6017..c30fd5dcfd9 100644 --- a/mne/preprocessing/tests/test_xdawn.py +++ b/mne/preprocessing/tests/test_xdawn.py @@ -17,10 +17,12 @@ pick_types, read_events, ) -from mne.decoding import Vectorizer from mne.fixes import _safe_svd from mne.io import read_raw_fif -from mne.preprocessing.xdawn import Xdawn, _XdawnTransformer + +pytest.importorskip("sklearn") + +from mne.preprocessing.xdawn import Xdawn, _XdawnTransformer # noqa: E402 base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" @@ -353,6 +355,8 @@ def test_xdawn_decoding_performance(): from sklearn.pipeline import make_pipeline from sklearn.preprocessing import MinMaxScaler + from mne.decoding import Vectorizer + n_xdawn_comps = 3 expected_accuracy = 0.98 diff --git a/mne/tests/test_docstring_parameters.py b/mne/tests/test_docstring_parameters.py index b89e8920c42..d32e62a454e 100644 --- a/mne/tests/test_docstring_parameters.py +++ b/mne/tests/test_docstring_parameters.py @@ -109,7 +109,6 @@ def _func_name(func, cls=None): }, ), (list, {"append", "count", "extend", "index", "insert", "pop", "remove", "sort"}), - (mne.fixes.BaseEstimator, {"get_params", "set_params", "fit_transform"}), ) @@ -175,7 +174,12 @@ def test_docstring_parameters(): module = __import__(name, globals()) for submod in name.split(".")[1:]: module = getattr(module, submod) - classes = inspect.getmembers(module, inspect.isclass) + try: + classes = inspect.getmembers(module, inspect.isclass) + except ModuleNotFoundError as exc: # e.g., mne.decoding but no sklearn + if "'sklearn'" in str(exc): + continue + raise for cname, cls in classes: if cname.startswith("_"): continue @@ -326,7 +330,12 @@ def test_documented(): module = __import__(name, globals()) for submod in name.split(".")[1:]: module = getattr(module, submod) - classes = inspect.getmembers(module, inspect.isclass) + try: + classes = inspect.getmembers(module, inspect.isclass) + except ModuleNotFoundError as exc: # e.g., mne.decoding but no sklearn + if "'sklearn'" in str(exc): + continue + raise functions = inspect.getmembers(module, inspect.isfunction) checks = list(classes) + list(functions) for this_name, cf in checks: diff --git a/pyproject.toml b/pyproject.toml index 55f02774163..725b453a9aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -255,6 +255,10 @@ ignore-decorators = [ "examples/preprocessing/eeg_bridging.py" = [ "E501", # line too long ] +"mne/decoding/tests/test_*.py" = [ + "E402", # Module level import not at top of file +] + [tool.pytest.ini_options] # -r f (failed), E (error), s (skipped), x (xfail), X (xpassed), w (warnings) diff --git a/tools/install_pre_requirements.sh b/tools/install_pre_requirements.sh index 85c797be7aa..88c21f52225 100755 --- a/tools/install_pre_requirements.sh +++ b/tools/install_pre_requirements.sh @@ -25,13 +25,11 @@ if [[ "${PLATFORM}" == "Linux" ]]; then fi python -m pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 \ --index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \ - "numpy>=2.1.0.dev0" "scipy>=1.15.0.dev0" \ + "numpy>=2.1.0.dev0" "scikit-learn>=1.6.dev0" "scipy>=1.15.0.dev0" \ "statsmodels>=0.15.0.dev0" "pandas>=3.0.0.dev0" "matplotlib>=3.10.0.dev0" \ $OTHERS # No Numba because it forces an old NumPy version -# No sklearn from SPNW until we figure out https://github.com/scikit-learn/scikit-learn/pull/29677 -pip install $STD_ARGS --upgrade scikit-learn if [[ "${PLATFORM}" == "Linux" ]]; then echo "pymatreader" diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index d06eac34285..f083f733239 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -41,6 +41,10 @@ verbose_debug metadata_routing +# Decoding +_._more_tags +deep + # Backward compat or rarely used RawFIF estimate_head_mri_t diff --git a/tutorials/forward/90_compute_covariance.py b/tutorials/forward/90_compute_covariance.py index 718565ffe71..37c2329f439 100644 --- a/tutorials/forward/90_compute_covariance.py +++ b/tutorials/forward/90_compute_covariance.py @@ -5,11 +5,10 @@ Computing a covariance matrix ============================= -Many methods in MNE, including source estimation and some classification -algorithms, require covariance estimations from the recordings. -In this tutorial we cover the basics of sensor covariance computations and -construct a noise covariance matrix that can be used when computing the -minimum-norm inverse solution. For more information, see +Many methods in MNE, including source estimation and some classification algorithms, +require covariance estimations from the recordings. In this tutorial we cover the basics +of sensor covariance computations and construct a noise covariance matrix that can be +used when computing the minimum-norm inverse solution. For more information, see :ref:`minimum_norm_estimates`. """ From 528e04658c5ea7a5e55bcb5889ebb4c6d6847284 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Fri, 13 Sep 2024 15:20:12 +0200 Subject: [PATCH 13/55] [MRG] Add additional tests for ant reader (#12841) --- doc/changes/devel/12792.newfeature.rst | 2 +- mne/datasets/config.py | 4 +- mne/io/ant/ant.py | 4 +- mne/io/ant/tests/test_ant.py | 251 +++++++++++++++++++++---- pyproject.toml | 2 +- tools/vulture_allowlist.py | 1 + tutorials/io/20_reading_eeg_data.py | 29 ++- 7 files changed, 253 insertions(+), 40 deletions(-) diff --git a/doc/changes/devel/12792.newfeature.rst b/doc/changes/devel/12792.newfeature.rst index 8866b5c201a..81ef79c8a11 100644 --- a/doc/changes/devel/12792.newfeature.rst +++ b/doc/changes/devel/12792.newfeature.rst @@ -1 +1 @@ -Add reader for ANT Neuro files in the ``*.cnt`` format with :func:`~mne.io.read_raw_ant`, by `Mathieu Scheltienne`_ and `Eric Larson`_. +Add reader for ANT Neuro files in the ``*.cnt`` format with :func:`~mne.io.read_raw_ant`, by `Mathieu Scheltienne`_, `Eric Larson`_ and `Proloy Das`_. diff --git a/mne/datasets/config.py b/mne/datasets/config.py index ca3345acb78..ccd4babacd9 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓↓↓↓ RELEASES = dict( - testing="0.155", + testing="0.156", misc="0.27", phantom_kit="0.2", ucl_opm_auditory="0.2", @@ -115,7 +115,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS["testing"] = dict( archive_name=f"{TESTING_VERSIONED}.tar.gz", - hash="md5:a3ddc1fbfd48830207112db13c3fdd6a", + hash="md5:d94fe9f3abe949a507eaeb865fb84a3f", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" f'tar.gz/{RELEASES["testing"]}' diff --git a/mne/io/ant/ant.py b/mne/io/ant/ant.py index dd1fb79638f..cc8dbe05dfe 100644 --- a/mne/io/ant/ant.py +++ b/mne/io/ant/ant.py @@ -122,7 +122,7 @@ def __init__( _validate_type(impedance_annotation, (str,), "impedance_annotation") if len(impedance_annotation) == 0: raise ValueError("The impedance annotation cannot be an empty string.") - cnt = read_cnt(str(fname)) + cnt = read_cnt(fname) # parse channels, sampling frequency, and create info ch_info = read_info(cnt) # load in 2 lines for compat with antio 0.2 and 0.3 ch_names, ch_units, ch_refs = ch_info[0], ch_info[1], ch_info[2] @@ -198,7 +198,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): i_start = max(start, first_samp) i_stop = min(stop, this_n_times + first_samp) # read and scale data array - cnt = read_cnt(str(self._filenames[fi])) + cnt = read_cnt(self._filenames[fi]) one = read_data(cnt, i_start, i_stop) _scale_data(one, ch_units) data_view = data[:, i_start - start : i_stop - start] diff --git a/mne/io/ant/tests/test_ant.py b/mne/io/ant/tests/test_ant.py index bf1e4556808..825ef0bde95 100644 --- a/mne/io/ant/tests/test_ant.py +++ b/mne/io/ant/tests/test_ant.py @@ -6,24 +6,26 @@ import warnings from pathlib import Path -from typing import TYPE_CHECKING import numpy as np import pytest from numpy.testing import assert_allclose from mne import Annotations +from mne._fiff.constants import FIFF from mne.datasets import testing from mne.io import BaseRaw, read_raw, read_raw_ant, read_raw_brainvision from mne.io.ant.ant import RawANT -if TYPE_CHECKING: - from pathlib import Path - -pytest.importorskip("antio", minversion="0.3.0") +pytest.importorskip("antio", minversion="0.4.0") data_path = testing.data_path(download=False) / "antio" +TypeDataset = dict[ + str, dict[str, Path] | str | int | tuple[str, str, str] | dict[str, str] | None +] + + def read_raw_bv(fname: Path) -> BaseRaw: """Read a brainvision file exported from eego. @@ -45,7 +47,7 @@ def read_raw_bv(fname: Path) -> BaseRaw: @pytest.fixture(scope="module") -def ca_208() -> dict[str, dict[str, Path] | str | int | dict[str, str | int]]: +def ca_208() -> TypeDataset: """Return the paths to the CA_208 dataset containing 64 channel gel recordings.""" cnt = { "short": data_path / "CA_208" / "test_CA_208.cnt", @@ -71,7 +73,41 @@ def ca_208() -> dict[str, dict[str, Path] | str | int | dict[str, str | int]]: @pytest.fixture(scope="module") -def andy_101() -> dict[str, dict[str, Path] | str | int | dict[str, str | int]]: +def ca_208_refs() -> TypeDataset: + """Return the paths and info to the CA_208_refs dataset. + + The following montage was applid on export: + - highpass: 0.3 Hz - lowpass: 30 Hz + - Fp1, Fpz, Fp2 referenced to Fz + - CP3, CP4 referenced to Cz + - others to CPz + """ + cnt = { + "short": data_path / "CA_208_refs" / "test-ref.cnt", + "legacy": data_path / "CA_208_refs" / "test-ref-legacy.cnt", + } + bv = { + "short": cnt["short"].with_suffix(".vhdr"), + } + return { + "cnt": cnt, + "bv": bv, + "n_eeg": 64, + "n_misc": 0, + "meas_date": "2024-09-09-10-57-44+0000", + "patient_info": { + "name": "antio test", + "his_id": "", + "birthday": "2024-08-14", + "sex": 0, + }, + "machine_info": ("eego", "EE_225", ""), + "hospital": "", + } + + +@pytest.fixture(scope="module") +def andy_101() -> TypeDataset: """Return the path and info to the andy_101 dataset.""" cnt = { "short": data_path / "andy_101" / "Andy_101-raw.cnt", @@ -94,8 +130,96 @@ def andy_101() -> dict[str, dict[str, Path] | str | int | dict[str, str | int]]: } +@pytest.fixture(scope="module") +def na_271() -> TypeDataset: + """Return the path to a dataset containing 128 channel recording. + + The recording was done with an NA_271 net dipped in saline solution. + """ + cnt = { + "short": data_path / "NA_271" / "test-na-271.cnt", + "legacy": data_path / "NA_271" / "test-na-271-legacy.cnt", + } + bv = { + "short": cnt["short"].with_suffix(".vhdr"), + } + return { + "cnt": cnt, + "bv": bv, + "n_eeg": 128, + "n_misc": 0, + "meas_date": "2024-09-06-10-45-07+0000", + "patient_info": { + "name": "antio test", + "his_id": "", + "birthday": "2024-08-14", + "sex": 0, + }, + "machine_info": ("eego", "EE_226", ""), + "hospital": "", + } + + +@pytest.fixture(scope="module") +def na_271_bips() -> TypeDataset: + """Return the path to a dataset containing 128 channel recording. + + The recording was done with an NA_271 net dipped in saline solution and includes + bipolar channels. + """ + cnt = { + "short": data_path / "NA_271_bips" / "test-na-271.cnt", + "legacy": data_path / "NA_271_bips" / "test-na-271-legacy.cnt", + } + bv = { + "short": cnt["short"].with_suffix(".vhdr"), + } + return { + "cnt": cnt, + "bv": bv, + "n_eeg": 128, + "n_misc": 6, + "meas_date": "2024-09-06-10-37-23+0000", + "patient_info": { + "name": "antio test", + "his_id": "", + "birthday": "2024-08-14", + "sex": 0, + }, + "machine_info": ("eego", "EE_226", ""), + "hospital": "", + } + + +@pytest.fixture(scope="module") +def user_annotations() -> TypeDataset: + """Return the path to a dataset containing user annotations with floating pins.""" + cnt = { + "short": data_path / "user_annotations" / "test-user-annotation.cnt", + "legacy": data_path / "user_annotations" / "test-user-annotation-legacy.cnt", + } + bv = { + "short": cnt["short"].with_suffix(".vhdr"), + } + return { + "cnt": cnt, + "bv": bv, + "n_eeg": 64, + "n_misc": 0, + "meas_date": "2024-08-29-16-15-44+0000", + "patient_info": { + "name": "test test", + "his_id": "", + "birthday": "2024-02-06", + "sex": 0, + }, + "machine_info": ("eego", "EE_225", ""), + "hospital": "", + } + + @testing.requires_testing_data -@pytest.mark.parametrize("dataset", ["ca_208", "andy_101"]) +@pytest.mark.parametrize("dataset", ["ca_208", "andy_101", "na_271"]) def test_io_data(dataset, request): """Test loading of .cnt file.""" dataset = request.getfixturevalue(dataset) @@ -138,7 +262,7 @@ def test_io_data(dataset, request): @testing.requires_testing_data -@pytest.mark.parametrize("dataset", ["ca_208", "andy_101"]) +@pytest.mark.parametrize("dataset", ["ca_208", "andy_101", "na_271"]) def test_io_info(dataset, request): """Test the ifo loaded from a .cnt file.""" dataset = request.getfixturevalue(dataset) @@ -158,9 +282,7 @@ def test_io_info(dataset, request): @testing.requires_testing_data -def test_io_info_parse_misc( - ca_208: dict[str, dict[str, Path] | str | int | dict[str, str | int]], -): +def test_io_info_parse_misc(ca_208: TypeDataset): """Test parsing misc channels from a .cnt file.""" raw_cnt = read_raw_ant(ca_208["cnt"]["short"]) with pytest.warns( @@ -172,10 +294,26 @@ def test_io_info_parse_misc( assert raw_cnt.get_channel_types() == ["eeg"] * len(raw_cnt.ch_names) +def test_io_info_parse_non_standard_misc(na_271_bips: TypeDataset): + """Test parsing misc channels with modified names from a .cnt file.""" + with pytest.warns( + RuntimeWarning, match="EEG channels are not referenced to the same electrode" + ): + raw = read_raw_ant(na_271_bips["cnt"]["short"], misc=None) + assert raw.get_channel_types() == ["eeg"] * ( + na_271_bips["n_eeg"] + na_271_bips["n_misc"] + ) + raw = read_raw_ant( + na_271_bips["cnt"]["short"], preload=False, misc=r".{0,1}E.{1}G|Aux|Audio" + ) + assert ( + raw.get_channel_types() + == ["eeg"] * na_271_bips["n_eeg"] + ["misc"] * na_271_bips["n_misc"] + ) + + @testing.requires_testing_data -def test_io_info_parse_eog( - ca_208: dict[str, dict[str, Path] | str | int | dict[str, str | int]], -): +def test_io_info_parse_eog(ca_208: TypeDataset): """Test parsing EOG channels from a .cnt file.""" raw_cnt = read_raw_ant(ca_208["cnt"]["short"], eog="EOG") assert len(raw_cnt.ch_names) == ca_208["n_eeg"] + ca_208["n_misc"] @@ -186,7 +324,9 @@ def test_io_info_parse_eog( @testing.requires_testing_data -@pytest.mark.parametrize("dataset", ["andy_101", "ca_208"]) +@pytest.mark.parametrize( + "dataset", ["andy_101", "ca_208", "na_271", "user_annotations"] +) def test_subject_info(dataset, request): """Test reading the subject info.""" dataset = request.getfixturevalue(dataset) @@ -202,7 +342,9 @@ def test_subject_info(dataset, request): @testing.requires_testing_data -@pytest.mark.parametrize("dataset", ["andy_101", "ca_208"]) +@pytest.mark.parametrize( + "dataset", ["andy_101", "ca_208", "na_271", "user_annotations"] +) def test_machine_info(dataset, request): """Test reading the machine info.""" dataset = request.getfixturevalue(dataset) @@ -215,9 +357,7 @@ def test_machine_info(dataset, request): @testing.requires_testing_data -def test_io_amp_disconnection( - ca_208: dict[str, dict[str, Path] | str | int | dict[str, str | int]], -): +def test_io_amp_disconnection(ca_208: TypeDataset): """Test loading of .cnt file with amplifier disconnection.""" raw_cnt = read_raw_ant(ca_208["cnt"]["amp-dc"]) raw_bv = read_raw_bv(ca_208["bv"]["amp-dc"]) @@ -249,10 +389,7 @@ def test_io_amp_disconnection( @testing.requires_testing_data @pytest.mark.parametrize("description", ["impedance", "test"]) -def test_io_impedance( - ca_208: dict[str, dict[str, Path] | str | int | dict[str, str | int]], - description: str, -): +def test_io_impedance(ca_208: TypeDataset, description: str): """Test loading of impedances from a .cnt file.""" raw_cnt = read_raw_ant(ca_208["cnt"]["amp-dc"], impedance_annotation=description) assert isinstance(raw_cnt.impedances, list) @@ -267,9 +404,7 @@ def test_io_impedance( @testing.requires_testing_data -def test_io_segments( - ca_208: dict[str, dict[str, Path] | str | int | dict[str, str | int]], -): +def test_io_segments(ca_208: TypeDataset): """Test reading a .cnt file with segents (start/stop).""" raw_cnt = read_raw_ant(ca_208["cnt"]["start-stop"]) raw_bv = read_raw_bv(ca_208["bv"]["start-stop"]) @@ -277,9 +412,7 @@ def test_io_segments( @testing.requires_testing_data -def test_annotations_and_preload( - ca_208: dict[str, dict[str, Path] | str | int | dict[str, str | int]], -): +def test_annotations_and_preload(ca_208: TypeDataset): """Test annotation loading with preload True/False.""" raw_cnt_preloaded = read_raw_ant(ca_208["cnt"]["short"], preload=True) assert len(raw_cnt_preloaded.annotations) == 2 # impedance measurements, start/end @@ -302,9 +435,61 @@ def test_annotations_and_preload( @testing.requires_testing_data -def test_read_raw( - ca_208: dict[str, dict[str, Path] | str | int | dict[str, str | int]], -): +def test_read_raw(ca_208: TypeDataset): """Test loading through read_raw.""" raw = read_raw(ca_208["cnt"]["short"]) assert isinstance(raw, RawANT) + + +@testing.requires_testing_data +@pytest.mark.parametrize("preload", [True, False]) +def test_read_raw_with_user_annotations(user_annotations: TypeDataset, preload: bool): + """Test reading raw objects which have user annotations.""" + raw = read_raw_ant(user_annotations["cnt"]["short"], preload=preload) + assert raw.annotations + assert "1000/user-annot" in raw.annotations.description + assert "1000/user-annot-2" in raw.annotations.description + + +@testing.requires_testing_data +@pytest.mark.parametrize("dataset", ["na_271", "user_annotations"]) +def test_read_raw_legacy_format(dataset, request): + """Test reading the legacy CNT format.""" + dataset = request.getfixturevalue(dataset) + raw_cnt = read_raw_ant(dataset["cnt"]["short"]) # preload=False + raw_bv = read_raw_bv(dataset["bv"]["short"]) + assert raw_cnt.ch_names == raw_bv.ch_names + assert raw_cnt.info["sfreq"] == raw_bv.info["sfreq"] + assert ( + raw_cnt.get_channel_types() + == ["eeg"] * dataset["n_eeg"] + ["misc"] * dataset["n_misc"] + ) + assert_allclose( + (raw_bv.info["meas_date"] - raw_cnt.info["meas_date"]).total_seconds(), + 0, + atol=1e-3, + ) + + +@testing.requires_testing_data +def test_read_raw_custom_reference(ca_208_refs: TypeDataset): + """Test reading a CNT file with custom EEG references.""" + with pytest.warns( + RuntimeWarning, match="EEG channels are not referenced to the same electrode" + ): + raw = read_raw_ant(ca_208_refs["cnt"]["short"], preload=False) + for ch in raw.info["chs"]: + assert ch["coil_type"] == FIFF.FIFFV_COIL_EEG + bipolars = ("Fp1-Fz", "Fpz-Fz", "Fp2-Fz", "CP3-Cz", "CP4-Cz") + with pytest.warns( + RuntimeWarning, match="EEG channels are not referenced to the same electrode" + ): + raw = read_raw_ant( + ca_208_refs["cnt"]["short"], preload=False, bipolars=bipolars + ) + assert all(elt in raw.ch_names for elt in bipolars) + for ch in raw.info["chs"]: + if ch["ch_name"] in bipolars: + assert ch["coil_type"] == FIFF.FIFFV_COIL_EEG_BIPOLAR + else: + assert ch["coil_type"] == FIFF.FIFFV_COIL_EEG diff --git a/pyproject.toml b/pyproject.toml index 725b453a9aa..e82012b8d6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,7 @@ full-no-qt = [ "snirf", "defusedxml", "neo", - "antio", + "antio>=0.4.0", ] full = ["mne[full-no-qt]", "PyQt6!=6.6.0", "PyQt6-Qt6!=6.6.0,!=6.7.0"] full-pyqt6 = ["mne[full]"] diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index f083f733239..ce5e95124d4 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -79,6 +79,7 @@ # mne/io/ant/tests/test_ant.py andy_101 +na_271 # mne/io/snirf/tests/test_snirf.py _.dataTimeSeries diff --git a/tutorials/io/20_reading_eeg_data.py b/tutorials/io/20_reading_eeg_data.py index 155cbbfc061..2544e57f60c 100644 --- a/tutorials/io/20_reading_eeg_data.py +++ b/tutorials/io/20_reading_eeg_data.py @@ -120,6 +120,7 @@ non-data channel does not fit to the sphere, it is assigned a z-value of 0. .. warning:: + Reading channel locations from the file header may be dangerous, as the x_coord and y_coord in the ELECTLOC section of the header do not necessarily translate to absolute locations. Furthermore, EEG electrode locations that @@ -127,6 +128,32 @@ If you are not sure about the channel locations in the header, using a montage is encouraged. +.. warning:: + + ANT Neuro also uses a file format with the ``.cnt`` extension, but it is different + from the Neuroscan CNT format. The ANT Neuro format is supported by the function + :func:`mne.io.read_raw_ant`. + + +.. _import-ant: + +ANT Neuro CNT (.cnt) +==================== + +CNT files from the eego software of ANT Neuro can be read using +:func:`mne.io.read_raw_ant`. The channels can be automatically recognized as auxiliary +``'misc'`` channels if the regular expression in the argument ``misc`` correctly +captures the channel names. Same for EOG channels with the regular expression in the +argument ``eog``. Note that if a montage with specific bipolar channels is applied on +export, they can be loaded as EEG bipolar channel pairs by providing the argument +``bipolars``. All other EEG channels will be loaded as regular EEG channels referenced +to the same electrode. + +.. warning:: + + Neuroscan also uses a file format with the ``.cnt`` extension, but it is different + from the eego CNT format. The Neuroscan CNT format is supported by the function + :func:`mne.io.read_raw_cnt`. .. _import-egi: @@ -248,6 +275,6 @@ When using locations of fiducial points, the digitization data are converted to the MEG head coordinate system employed in the MNE software, see :ref:`coordinate_systems`. -""" # noqa:E501 +""" # noqa: E501 # %% From 2f299276fecedcb8b3626ea40fc4f1aecff51bce Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 13 Sep 2024 09:20:37 -0400 Subject: [PATCH 14/55] MAINT: Improve PR template (#12833) --- .github/PULL_REQUEST_TEMPLATE.md | 15 +++++++++++---- tools/github_actions_test.sh | 3 +++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 837e4eaa1e2..231488d2d47 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -15,13 +15,20 @@ Again, thanks for contributing! --> -#### Reference issue -Example: Fixes #1234. +#### Reference issue (if any) + + #### What does this implement/fix? -Explain your changes. + + #### Additional information -Any additional information you think is important. + + diff --git a/tools/github_actions_test.sh b/tools/github_actions_test.sh index 78cc063d016..4cdd202223f 100755 --- a/tools/github_actions_test.sh +++ b/tools/github_actions_test.sh @@ -19,10 +19,13 @@ if [[ ! -z "$CONDA_ENV" ]] && [[ "${RUNNER_OS}" != "Windows" ]]; then cd .. INSTALL_PATH=$(python -c "import mne, pathlib; print(str(pathlib.Path(mne.__file__).parents[1]))") echo "Copying tests from $(pwd)/mne-python/mne/ to ${INSTALL_PATH}/mne/" + echo "::group::rsync" rsync -a --partial --progress --prune-empty-dirs --exclude="*.pyc" --include="**/" --include="**/tests/*" --include="**/tests/data/**" --exclude="**" ./mne-python/mne/ ${INSTALL_PATH}/mne/ + echo "::endgroup::" cd $INSTALL_PATH echo "Executing from $(pwd)" fi + set -x pytest -m "${CONDITION}" --tb=short --cov=mne --cov-report xml --color=yes --junit-xml=$JUNIT_PATH -vv ${USE_DIRS} set +x From 670330a1ead288f11cc3ffd5ca9d65136ca55ccd Mon Sep 17 00:00:00 2001 From: Scott Huberty <52462026+scott-huberty@users.noreply.github.com> Date: Fri, 13 Sep 2024 12:05:02 -0700 Subject: [PATCH 15/55] Use SI units for eyetracking data, update tutorials (#12846) --- doc/changes/devel/12846.bugfix.rst | 1 + .../visualization/eyetracking_plot_heatmap.py | 4 ++ mne/defaults.py | 18 +++---- tutorials/io/70_reading_eyetracking_data.py | 50 ++++++++++++------- .../preprocessing/90_eyetracking_data.py | 40 ++++++++++++--- 5 files changed, 79 insertions(+), 34 deletions(-) create mode 100644 doc/changes/devel/12846.bugfix.rst diff --git a/doc/changes/devel/12846.bugfix.rst b/doc/changes/devel/12846.bugfix.rst new file mode 100644 index 00000000000..2b40e77490e --- /dev/null +++ b/doc/changes/devel/12846.bugfix.rst @@ -0,0 +1 @@ +Enforce SI units for Eyetracking data (eyegaze data should be radians of visual angle, not pixels. Pupil size data should be meters). Updated tutorials so demonstrate how to convert data to SI units before analyses (:gh:`12846`` by `Scott Huberty`_) \ No newline at end of file diff --git a/examples/visualization/eyetracking_plot_heatmap.py b/examples/visualization/eyetracking_plot_heatmap.py index a57857f34ad..07983685b5e 100644 --- a/examples/visualization/eyetracking_plot_heatmap.py +++ b/examples/visualization/eyetracking_plot_heatmap.py @@ -68,6 +68,10 @@ cmap = plt.get_cmap("viridis") plot_gaze(epochs["natural"], calibration=calibration, cmap=cmap, sigma=50) +# %% +# .. note:: The (0, 0) pixel coordinates are at the top-left of the +# trackable area of the screen for many eye trackers. + # %% # Overlaying plots with images # ---------------------------- diff --git a/mne/defaults.py b/mne/defaults.py index a2dd2a05250..a2769d79e6f 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -64,8 +64,8 @@ whitened="Z", gsr="S", temperature="C", - eyegaze="AU", - pupil="AU", + eyegaze="rad", + pupil="M", ), units=dict( mag="fT", @@ -92,8 +92,8 @@ whitened="Z", gsr="S", temperature="C", - eyegaze="AU", - pupil="AU", + eyegaze="rad", + pupil="µM", ), # scalings for the units scalings=dict( @@ -122,7 +122,7 @@ gsr=1.0, temperature=1.0, eyegaze=1.0, - pupil=1.0, + pupil=1e6, ), # rough guess for a good plot scalings_plot_raw=dict( @@ -156,8 +156,8 @@ gof=1e2, gsr=1.0, temperature=0.1, - eyegaze=3e-1, - pupil=1e3, + eyegaze=2e-1, + pupil=10e-6, ), scalings_cov_rank=dict( mag=1e12, @@ -183,8 +183,8 @@ hbo=(0, 20), hbr=(0, 20), csd=(-50.0, 50.0), - eyegaze=(0.0, 5000.0), - pupil=(0.0, 5000.0), + eyegaze=(-1, 1), + pupil=(0.0, 20), ), titles=dict( mag="Magnetometers", diff --git a/tutorials/io/70_reading_eyetracking_data.py b/tutorials/io/70_reading_eyetracking_data.py index ce4fcf41d9b..3cf72719e4c 100644 --- a/tutorials/io/70_reading_eyetracking_data.py +++ b/tutorials/io/70_reading_eyetracking_data.py @@ -78,29 +78,43 @@ new line, the y-coordinate *increased*, which is why the ``ypos_right`` channel in the plot below increases over time (for example, at about 4-seconds, and at about 8-seconds). + +.. seealso:: + + :ref:`tut-eyetrack` """ # %% -from mne.datasets import misc -from mne.io import read_raw_eyelink +import mne # %% -fpath = misc.data_path() / "eyetracking" / "eyelink" -raw = read_raw_eyelink(fpath / "px_textpage_ws.asc", create_annotations=["blinks"]) -custom_scalings = dict(eyegaze=1e3) -raw.pick(picks="eyetrack").plot(scalings=custom_scalings) +fpath = mne.datasets.misc.data_path() / "eyetracking" / "eyelink" +fname = fpath / "px_textpage_ws.asc" +raw = mne.io.read_raw_eyelink(fname, create_annotations=["blinks"]) +cal = mne.preprocessing.eyetracking.read_eyelink_calibration( + fname, + screen_distance=0.7, + screen_size=(0.53, 0.3), + screen_resolution=(1920, 1080), +)[0] +mne.preprocessing.eyetracking.convert_units(raw, calibration=cal, to="radians") +# %% +# Visualizing the data +# ^^^^^^^^^^^^^^^^^^^^ # %% -# .. important:: The (0, 0) pixel coordinates are at the top-left of the -# trackable area of the screen. Gaze towards lower areas of the -# the screen will yield a relatively higher y-coordinate. -# -# Note that we passed a custom `dict` to the ``'scalings'`` argument of -# `mne.io.Raw.plot`. This is because MNE's default plot scalings for eye -# position data are calibrated for HREF data, which are stored in radians -# (read below). +cal.plot() +# %% +custom_scalings = dict(pupil=1e3) +raw.pick(picks="eyetrack").plot(scalings=custom_scalings) + +# %% +# Note that we passed a custom `dict` to the ``'scalings'`` argument of +# `mne.io.Raw.plot`. This is because MNE expects the data to be in SI units +# (radians for eyegaze data, and meters for pupil size data), but we did not convert +# the pupil size data in this example. # %% # Head-Referenced Eye Angle (HREF) @@ -124,9 +138,11 @@ # %% -fpath = misc.data_path() / "eyetracking" / "eyelink" -raw = read_raw_eyelink(fpath / "HREF_textpage_ws.asc", create_annotations=["blinks"]) -raw.pick(picks="eyetrack").plot() +fpath = mne.datasets.misc.data_path() / "eyetracking" / "eyelink" +fname_href = fpath / "HREF_textpage_ws.asc" +raw = mne.io.read_raw_eyelink(fname_href, create_annotations=["blinks"]) +custom_scalings = dict(pupil=1e3) +raw.pick(picks="eyetrack").plot(scalings=custom_scalings) # %% # Pupil Position diff --git a/tutorials/preprocessing/90_eyetracking_data.py b/tutorials/preprocessing/90_eyetracking_data.py index bad4eeeda67..85f5d80bf82 100644 --- a/tutorials/preprocessing/90_eyetracking_data.py +++ b/tutorials/preprocessing/90_eyetracking_data.py @@ -99,16 +99,32 @@ first_cal.plot() +# %% +# Standardizing eyetracking data to SI units +# ------------------------------------------ +# +# EyeLink stores eyegaze positions in pixels, and pupil size in arbitrary units. +# MNE-Python expects eyegaze positions to be in radians of visual angle, and pupil +# size to be in meters. We can convert the eyegaze positions to radians using +# :func:`~mne.preprocessing.eyetracking.convert_units`. We'll pass the calibration +# object we created above, after specifying the screen resolution, screen size, and +# screen distance. + +first_cal["screen_resolution"] = (1920, 1080) +first_cal["screen_size"] = (0.53, 0.3) +first_cal["screen_distance"] = 0.9 +mne.preprocessing.eyetracking.convert_units(raw_et, calibration=first_cal, to="radians") + # %% # Plot the raw eye-tracking data # ------------------------------ # -# Let's plot the raw eye-tracking data. We'll pass a custom `dict` into -# the scalings argument to make the eyegaze channel traces legible when plotting, -# since this file contains pixel position data (as opposed to eye angles, -# which are reported in radians). +# Let's plot the raw eye-tracking data. Since we did not convert the pupil size to +# meters, we'll pass a custom `dict` into the scalings argument to make the pupil size +# traces legible when plotting. -raw_et.plot(scalings=dict(eyegaze=1e3)) +ps_scalings = dict(pupil=1e3) +raw_et.plot(scalings=ps_scalings) # %% # Handling blink artifacts @@ -189,7 +205,13 @@ picks_idx = mne.pick_channels( raw_et.ch_names, frontal + occipital + pupil, ordered=True ) -raw_et.plot(events=et_events, event_id=event_dict, event_color="g", order=picks_idx) +raw_et.plot( + events=et_events, + event_id=event_dict, + event_color="g", + order=picks_idx, + scalings=ps_scalings, +) # %% @@ -203,14 +225,16 @@ raw_et, events=et_events, event_id=event_dict, tmin=-0.3, tmax=3, baseline=None ) del raw_et # free up some memory -epochs[:8].plot(events=et_events, event_id=event_dict, order=picks_idx) +epochs[:8].plot( + events=et_events, event_id=event_dict, order=picks_idx, scalings=ps_scalings +) # %% # For this experiment, the participant was instructed to fixate on a crosshair in the # center of the screen. Let's plot the gaze position data to confirm that the # participant primarily kept their gaze fixated at the center of the screen. -plot_gaze(epochs, width=1920, height=1080) +plot_gaze(epochs, calibration=first_cal) # %% # .. seealso:: :ref:`tut-eyetrack-heatmap` From e999e853b263c8b48a6fcfeb60c82b445e20d88b Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Tue, 17 Sep 2024 01:29:01 +0200 Subject: [PATCH 16/55] [BUG] Allow `Epochs.compute_tfr()` for the multitaper method and complex/phase outputs (#12842) --- .github/workflows/tests.yml | 2 +- doc/changes/devel/12842.bugfix.rst | 1 + mne/decoding/receptive_field.py | 9 +++++---- mne/decoding/tests/test_search_light.py | 1 + mne/decoding/time_delaying_ridge.py | 22 +++++++++++++--------- mne/time_frequency/tests/test_tfr.py | 7 +++++++ mne/time_frequency/tfr.py | 3 ++- 7 files changed, 30 insertions(+), 15 deletions(-) create mode 100644 doc/changes/devel/12842.bugfix.rst diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 571c3943ae7..571d4329831 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -131,4 +131,4 @@ jobs: - uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} - if: success() + if: always() diff --git a/doc/changes/devel/12842.bugfix.rst b/doc/changes/devel/12842.bugfix.rst new file mode 100644 index 00000000000..75f83683b8f --- /dev/null +++ b/doc/changes/devel/12842.bugfix.rst @@ -0,0 +1 @@ +Fix bug where :meth:`mne.Epochs.compute_tfr` could not be used with the multitaper method and complex or phase outputs, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/decoding/receptive_field.py b/mne/decoding/receptive_field.py index 5fc985a81b3..a9cc72d18ce 100644 --- a/mne/decoding/receptive_field.py +++ b/mne/decoding/receptive_field.py @@ -117,7 +117,7 @@ def __init__( ): self.tmin = tmin self.tmax = tmax - self.sfreq = float(sfreq) + self.sfreq = sfreq self.feature_names = feature_names self.estimator = 0.0 if estimator is None else estimator self.fit_intercept = fit_intercept @@ -154,7 +154,7 @@ def _delay_and_reshape(self, X, y=None): X, self.tmin, self.tmax, - self.sfreq, + self.sfreq_, fill_mean=self.fit_intercept_, ) X = _reshape_for_est(X) @@ -182,12 +182,13 @@ def fit(self, X, y): raise ValueError( f"scoring must be one of {sorted(_SCORERS.keys())}, got {self.scoring} " ) + self.sfreq_ = float(self.sfreq) X, y, _, self._y_dim = self._check_dimensions(X, y) if self.tmin > self.tmax: raise ValueError(f"tmin ({self.tmin}) must be at most tmax ({self.tmax})") # Initialize delays - self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq) + self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq_) # Define the slice that we should use in the middle self.valid_samples_ = _delays_to_slice(self.delays_) @@ -200,7 +201,7 @@ def fit(self, X, y): estimator = TimeDelayingRidge( self.tmin, self.tmax, - self.sfreq, + self.sfreq_, alpha=self.estimator, fit_intercept=self.fit_intercept_, n_jobs=self.n_jobs, diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index fe605abca06..9e15a1df59b 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -358,6 +358,7 @@ def test_sklearn_compliance(estimator, check): "check_transformer_data_not_an_array", "check_n_features_in", "check_fit2d_predict1d", + "check_do_not_raise_errors_in_init_or_set_params", ) if any(ignore in str(check) for ignore in ignores): return diff --git a/mne/decoding/time_delaying_ridge.py b/mne/decoding/time_delaying_ridge.py index e824a15be75..b80b36d3922 100644 --- a/mne/decoding/time_delaying_ridge.py +++ b/mne/decoding/time_delaying_ridge.py @@ -287,12 +287,10 @@ def __init__( n_jobs=None, edge_correction=True, ): - if tmin > tmax: - raise ValueError(f"tmin must be <= tmax, got {tmin} and {tmax}") - self.tmin = float(tmin) - self.tmax = float(tmax) - self.sfreq = float(sfreq) - self.alpha = float(alpha) + self.tmin = tmin + self.tmax = tmax + self.sfreq = sfreq + self.alpha = alpha self.reg_type = reg_type self.fit_intercept = fit_intercept self.edge_correction = edge_correction @@ -300,11 +298,11 @@ def __init__( @property def _smin(self): - return int(round(self.tmin * self.sfreq)) + return int(round(self.tmin_ * self.sfreq_)) @property def _smax(self): - return int(round(self.tmax * self.sfreq)) + 1 + return int(round(self.tmax_ * self.sfreq_)) + 1 def fit(self, X, y): """Estimate the coefficients of the linear model. @@ -323,6 +321,12 @@ def fit(self, X, y): """ _validate_type(X, "array-like", "X") _validate_type(y, "array-like", "y") + self.tmin_ = float(self.tmin) + self.tmax_ = float(self.tmax) + self.sfreq_ = float(self.sfreq) + self.alpha_ = float(self.alpha) + if self.tmin_ > self.tmax_: + raise ValueError(f"tmin must be <= tmax, got {self.tmin_} and {self.tmax_}") X = np.asarray(X, dtype=float) y = np.asarray(y, dtype=float) if X.ndim == 3: @@ -349,7 +353,7 @@ def fit(self, X, y): self.edge_correction, ) self.coef_ = _fit_corrs( - self.cov_, x_y_, n_ch_x, self.reg_type, self.alpha, n_ch_x + self.cov_, x_y_, n_ch_x, self.reg_type, self.alpha_, n_ch_x ) # This is the sklearn formula from LinearModel (will be 0. for no fit) if self.fit_intercept: diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 38099d8a3aa..cd3a97ab90a 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -1530,6 +1530,13 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): assert tfr.comment == "1" +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output): + """Test Epochs.compute_tfr(output="complex"/"phase").""" + tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output) + assert len(tfr.shape) == 5 + + @pytest.mark.parametrize("copy", (False, True)) def test_epochstfr_iter_evoked(epochs_tfr, copy): """Test EpochsTFR.iter_evoked().""" diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 98af67ff0a7..eaf173092bb 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1533,7 +1533,8 @@ def _compute_tfr(self, data, n_jobs, verbose): ] # deal with the "taper" dimension if self._needs_taper_dim: - expected_shape.insert(1, self._data.shape[1]) + tapers_dim = 1 if _get_instance_type_string(self) != "Epochs" else 2 + expected_shape.insert(1, self._data.shape[tapers_dim]) self._shape = tuple(expected_shape) @verbose From efbd97292592e92c2bdfe09e2fea9cf9b1d2e6f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Sep 2024 01:57:03 +0000 Subject: [PATCH 17/55] [pre-commit.ci] pre-commit autoupdate (#12856) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c369dc630ea..dfe129eb803 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.5 hooks: - id: ruff name: ruff lint mne From 4f58e814b4e5d01311f359f3d8f8e2cbca391fbb Mon Sep 17 00:00:00 2001 From: Scott Huberty <52462026+scott-huberty@users.noreply.github.com> Date: Wed, 18 Sep 2024 10:37:44 -0700 Subject: [PATCH 18/55] FIX: use meters for pupil SI, not moles (#12850) Co-authored-by: Stefan Appelhoff --- mne/defaults.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mne/defaults.py b/mne/defaults.py index a2769d79e6f..d5aab1a8d38 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -65,7 +65,7 @@ gsr="S", temperature="C", eyegaze="rad", - pupil="M", + pupil="m", ), units=dict( mag="fT", @@ -93,7 +93,7 @@ gsr="S", temperature="C", eyegaze="rad", - pupil="µM", + pupil="µm", ), # scalings for the units scalings=dict( @@ -157,7 +157,7 @@ gsr=1.0, temperature=0.1, eyegaze=2e-1, - pupil=10e-6, + pupil=1e-2, ), scalings_cov_rank=dict( mag=1e12, @@ -184,7 +184,7 @@ hbr=(0, 20), csd=(-50.0, 50.0), eyegaze=(-1, 1), - pupil=(0.0, 20), + pupil=(-1.0, 1.0), ), titles=dict( mag="Magnetometers", From 2cab0257a430127c4bae348380ab2be99bffb412 Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Wed, 18 Sep 2024 20:20:30 +0200 Subject: [PATCH 19/55] [BUG] Fix ignored `colorbar` flag in `plot_psds_topomap()` (#12853) --- doc/changes/devel/12853.bugfix.rst | 1 + mne/viz/tests/test_topomap.py | 18 ++++++++++++++++++ mne/viz/topomap.py | 4 ++-- 3 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 doc/changes/devel/12853.bugfix.rst diff --git a/doc/changes/devel/12853.bugfix.rst b/doc/changes/devel/12853.bugfix.rst new file mode 100644 index 00000000000..18c8afbb8ea --- /dev/null +++ b/doc/changes/devel/12853.bugfix.rst @@ -0,0 +1 @@ +Prevent the ``colorbar`` parameter being ignored in topomap plots such as :meth:`mne.time_frequency.Spectrum.plot_topomap`, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index 859699620ab..afa9341c00e 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -568,6 +568,24 @@ def patch(): assert_array_equal(evoked_grad.info["bads"], orig_bads) +def test_plot_psds_topomap_colorbar(): + """Test plot_psds_topomap colorbar option.""" + raw = read_raw_fif(raw_fname) + picks = pick_types(raw.info, meg="grad") + info = pick_info(raw.info, picks) + freqs = np.arange(3.0, 9.5) + rng = np.random.default_rng(42) + psd = np.abs(rng.standard_normal((len(picks), len(freqs)))) + bands = {"theta": [4, 8]} + + plt.close("all") + fig_cbar = plot_psds_topomap(psd, freqs, info, colorbar=True, bands=bands) + assert len(fig_cbar.axes) == 2 + + fig_nocbar = plot_psds_topomap(psd, freqs, info, colorbar=False, bands=bands) + assert len(fig_nocbar.axes) == 1 + + def test_plot_tfr_topomap(): """Test plotting of TFR data.""" raw = read_raw_fif(raw_fname) diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index c0e8b073447..147919a9c9d 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -2835,7 +2835,7 @@ def plot_psds_topomap( for ax, _mask, _data, (title, (fmin, fmax)) in zip( axes, freq_masks, band_data, bands.items() ): - colorbar = (not joint_vlim) or ax == axes[-1] + plot_colorbar = False if not colorbar else (not joint_vlim) or ax == axes[-1] _plot_topomap_multi_cbar( _data, pos, @@ -2844,7 +2844,7 @@ def plot_psds_topomap( vlim=vlim, cmap=cmap, outlines=outlines, - colorbar=colorbar, + colorbar=plot_colorbar, unit=unit, cbar_fmt=cbar_fmt, sphere=sphere, From efe8b6a9ce87c9a45dcdc20a8c44e99d2efb6c13 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 18 Sep 2024 14:48:24 -0400 Subject: [PATCH 20/55] DOC: Document installer uninstallation (#12845) --- doc/Makefile | 3 + doc/_static/js/set_installer_tab.js | 35 ++++++++---- doc/_static/js/update_installer_version.js | 2 +- doc/install/installers.rst | 64 ++++++++++++++++++++-- 4 files changed, 87 insertions(+), 17 deletions(-) diff --git a/doc/Makefile b/doc/Makefile index ab8219473b0..c0badd0bd9b 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -76,3 +76,6 @@ view: @python -c "import webbrowser; webbrowser.open_new_tab('file://$(PWD)/_build/html/sg_execution_times.html')" show: view + +serve: + python -m http.server -d _build/html diff --git a/doc/_static/js/set_installer_tab.js b/doc/_static/js/set_installer_tab.js index 5b6b737a565..3b4c043387a 100644 --- a/doc/_static/js/set_installer_tab.js +++ b/doc/_static/js/set_installer_tab.js @@ -12,21 +12,34 @@ function setTabs() { } if (navigator.userAgent.indexOf("Mac") !== -1) { // there's no good way to distinguish intel vs M1 in javascript so we - // just default to showing the first of the 2 macOS tabs - platform = "macos-intel"; + // just default to showing the most modern macOS installer + platform = "macos-apple"; } - let all_tab_nodes = document.querySelectorAll( - '.platform-selector-tabset')[0].children; - let input_nodes = [...all_tab_nodes].filter( - child => child.nodeName === "INPUT"); + var platform_short = platform.split("-")[0]; + let tab_label_nodes = [...document.querySelectorAll('.sd-tab-label')]; - let correct_label = tab_label_nodes.filter( + + let install_tab_nodes = document.querySelectorAll( + '.install-selector-tabset')[0].children; + let install_input_nodes = [...install_tab_nodes].filter( + child => child.nodeName === "INPUT"); + let install_label = tab_label_nodes.filter( // label.id is drawn from :name: property in the rST, which must // be unique across the whole site (*sigh*) - label => label.id.startsWith(platform))[0]; - let input_id = correct_label.getAttribute('for'); - let correct_input = input_nodes.filter(node => node.id === input_id)[0]; - correct_input.checked = true; + label => label.id.startsWith(`install-${platform}`))[0]; + let install_id = install_label.getAttribute('for'); + let install_input = install_input_nodes.filter(node => node.id === install_id)[0]; + install_input.checked = true; + + let uninstall_tab_nodes = document.querySelectorAll( + '.uninstall-selector-tabset')[0].children; + let uninstall_input_nodes = [...uninstall_tab_nodes].filter( + child => child.nodeName === "INPUT"); + let uninstall_label = tab_label_nodes.filter( + label => label.id.startsWith(`uninstall-${platform_short}`))[0]; + let uninstall_id = uninstall_label.getAttribute('for'); + let uninstall_input = uninstall_input_nodes.filter(node => node.id === uninstall_id)[0]; + uninstall_input.checked = true; } documentReady(setTabs); diff --git a/doc/_static/js/update_installer_version.js b/doc/_static/js/update_installer_version.js index 7cb8bdede1e..ad18890caf7 100644 --- a/doc/_static/js/update_installer_version.js +++ b/doc/_static/js/update_installer_version.js @@ -54,7 +54,7 @@ async function warnVersion() { title.innerText = "Warning"; inner.innerText = warn; outer.append(title, inner); - document.querySelectorAll('.platform-selector-tabset')[0].before(outer); + document.querySelectorAll('.install-selector-tabset')[0].before(outer); } } diff --git a/doc/install/installers.rst b/doc/install/installers.rst index 9f7932e911d..1aa352edcfb 100644 --- a/doc/install/installers.rst +++ b/doc/install/installers.rst @@ -7,12 +7,15 @@ MNE-Python installers are the easiest way to install MNE-Python and all dependencies. They also provide many additional Python packages and tools. Got any questions? Let us know on the `MNE Forum`_! +Platform-specific installers +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + .. tab-set:: - :class: platform-selector-tabset + :class: install-selector-tabset .. tab-item:: Linux :class-content: text-center - :name: linux-installers + :name: install-linux .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.8.0/MNE-Python-1.8.0_0-Linux.sh :ref-type: ref @@ -33,7 +36,7 @@ Python packages and tools. Got any questions? Let us know on the `MNE Forum`_! .. tab-item:: macOS (Intel) :class-content: text-center - :name: macos-intel-installers + :name: install-macos-intel .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.8.0/MNE-Python-1.8.0_0-macOS_Intel.pkg :ref-type: ref @@ -49,7 +52,7 @@ Python packages and tools. Got any questions? Let us know on the `MNE Forum`_! .. tab-item:: macOS (Apple Silicon) :class-content: text-center - :name: macos-apple-installers + :name: install-macos-apple .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.8.0/MNE-Python-1.8.0_0-macOS_M1.pkg :ref-type: ref @@ -65,7 +68,7 @@ Python packages and tools. Got any questions? Let us know on the `MNE Forum`_! .. tab-item:: Windows :class-content: text-center - :name: windows-installers + :name: install-windows .. button-link:: https://github.com/mne-tools/mne-installers/releases/download/v1.8.0/MNE-Python-1.8.0_0-Windows.exe :ref-type: ref @@ -107,3 +110,54 @@ bundles to the ``Applications`` folder on macOS. applications to start, especially on the very first run – which may take particularly long on Apple Silicon-based computers. Subsequent runs should usually be much faster. + +Uninstallation +^^^^^^^^^^^^^^ + +To remove the MNE-Python distribution provided by our installers above: + +1. Remove relevant lines from your shell initialization scripts if you + added them at installation time. To do this, you can run from the MNE Prompt: + + .. code-block:: bash + + $ conda init --reverse + + Or you can manually edit shell initialization scripts, e.g., ``~/.bashrc`` or + ``~/.bash_profile``. + +2. Follow the instructions below to remove the MNE-Python conda installation for your platform: + + .. tab-set:: + :class: uninstall-selector-tabset + + .. tab-item:: Linux + :name: uninstall-linux + + In a BASH terminal you can do: + + .. code-block:: bash + + $ which python + /home/username/mne-python/1.8.0_0/bin/python + $ rm -Rf /home/$USER/mne-python + $ rm /home/$USER/.local/share/applications/mne-python-*.desktop + + .. tab-item:: macOS + :name: uninstall-macos + + You can simply `drag the MNE-Python folder to the trash in the Finder `__. + + Alternatively, you can do something like: + + .. code-block:: bash + + $ which python + /Users/username/Applications/MNE-Python/1.8.0_0/.mne-python/bin/python + $ rm -Rf /Users/$USER/Applications/MNE-Python # if user-specific + $ rm -Rf /Applications/MNE-Python # if system-wide + + .. tab-item:: Windows + :name: uninstall-windows + + To uninstall MNE-Python, you can remove the application using the `Windows Control Panel `__. From fa841cbc30ce830e40c60a8a44518b84e2f7f0cf Mon Sep 17 00:00:00 2001 From: Stefan Appelhoff Date: Thu, 19 Sep 2024 20:08:27 +0200 Subject: [PATCH 21/55] DOC: impedances may be obtained for BrainVision (#12861) Co-authored-by: Clemens Brunner --- mne/io/brainvision/brainvision.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mne/io/brainvision/brainvision.py b/mne/io/brainvision/brainvision.py index 28f3db84a29..e58e04e07e6 100644 --- a/mne/io/brainvision/brainvision.py +++ b/mne/io/brainvision/brainvision.py @@ -63,6 +63,12 @@ class RawBrainVision(BaseRaw): Notes ----- + If the BrainVision header file contains impedance measurements, these may be + accessed using ``raw.impedances`` after reading using this function. However, + this attribute will NOT be available after a save and re-load of the data. + That is, it is only available when reading data directly from the BrainVision + header file. + BrainVision markers consist of a type and a description (in addition to other fields like onset and duration). In contrast, annotations in MNE only have a description. Therefore, a BrainVision marker of type "Stimulus" and description "S 1" will be @@ -977,6 +983,12 @@ def read_raw_brainvision( Notes ----- + If the BrainVision header file contains impedance measurements, these may be + accessed using ``raw.impedances`` after reading using this function. However, + this attribute will NOT be available after a save and re-load of the data. + That is, it is only available when reading data directly from the BrainVision + header file. + BrainVision markers consist of a type and a description (in addition to other fields like onset and duration). In contrast, annotations in MNE only have a description. Therefore, a BrainVision marker of type "Stimulus" and description "S 1" will be From e6f9c5df856418d331684dd3477e8c1d1c0f37e1 Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Fri, 20 Sep 2024 16:51:54 +0200 Subject: [PATCH 22/55] [MAINT] Update minimum Python version to 3.10 in `environment.yml` (#12865) --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index a0dbdf5ec49..64509188d47 100644 --- a/environment.yml +++ b/environment.yml @@ -2,7 +2,7 @@ name: mne channels: - conda-forge dependencies: - - python>=3.9 + - python>=3.10 - pip - numpy - scipy From 6f9646ff979294683bb385e63f3587aa3c7a6dd5 Mon Sep 17 00:00:00 2001 From: Stefan Appelhoff Date: Fri, 20 Sep 2024 17:58:44 +0200 Subject: [PATCH 23/55] DOC: ICA -> fix typo, add func ref, use list (#12860) Co-authored-by: Clemens Brunner --- mne/decoding/tests/test_base.py | 6 +++++- mne/preprocessing/ica.py | 15 ++++++++------- mne/utils/progressbar.py | 6 +++++- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 25fbba3fafd..0930d007d28 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import platform from contextlib import nullcontext import numpy as np @@ -312,6 +313,7 @@ def test_get_coef_multiclass(n_features, n_targets): ) # TODO: Need to fix this properly in LinearModel @pytest.mark.filterwarnings("ignore:'multi_class' was deprecated in.*:FutureWarning") +@pytest.mark.filterwarnings("ignore:lbfgs failed to converge.*:") def test_get_coef_multiclass_full(n_classes, n_channels, n_times): """Test a full example with pattern extraction.""" data = np.zeros((10 * n_classes, n_channels, n_times)) @@ -339,7 +341,9 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times): if n_times > 1: want += (n_times, n_times) assert scores.shape == want - assert_array_less(0.8, scores) + # On Windows LBFGS can fail to converge, so we need to be a bit more tol here + limit = 0.7 if platform.system() == "Windows" else 0.8 + assert_array_less(limit, scores) clf.fit(X, y) patterns = get_coef(clf, "patterns_", inverse_transform=True) assert patterns.shape == (n_classes, n_channels, n_times) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index eb4c0a663c4..b36ebfa9119 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -1939,14 +1939,15 @@ def find_bads_muscle( sphere=None, verbose=None, ): - """Detect muscle related components. + """Detect muscle-related components. Detection is based on :footcite:`DharmapraniEtAl2016` which uses data from a subject who has been temporarily paralyzed :footcite:`WhithamEtAl2007`. The criteria are threefold: - 1) Positive log-log spectral slope from 7 to 45 Hz - 2) Peripheral component power (farthest away from the vertex) - 3) A single focal point measured by low spatial smoothness + + #. Positive log-log spectral slope from 7 to 45 Hz + #. Peripheral component power (farthest away from the vertex) + #. A single focal point measured by low spatial smoothness The threshold is relative to the slope, focal point and smoothness of a typical muscle-related ICA component. Note the high frequency @@ -1970,14 +1971,14 @@ def find_bads_muscle( l_freq : float Low frequency for muscle-related power. h_freq : float - High frequency for msucle related power. + High frequency for muscle-related power. %(sphere_topomap_auto)s %(verbose)s Returns ------- muscle_idx : list of int - The indices of EOG related components, sorted by score. + The indices of muscle-related components, sorted by score. scores : np.ndarray of float, shape (``n_components_``) | list of array The correlation scores. @@ -2105,7 +2106,7 @@ def find_bads_eog( See Also -------- - find_bads_ecg, find_bads_ref + find_bads_ecg, find_bads_ref, find_bads_muscle """ _validate_type(threshold, (str, "numeric"), "threshold") if isinstance(threshold, str): diff --git a/mne/utils/progressbar.py b/mne/utils/progressbar.py index 4ada836ade7..b33efa5b72a 100644 --- a/mne/utils/progressbar.py +++ b/mne/utils/progressbar.py @@ -180,7 +180,11 @@ def __exit__(self, type_, value, traceback): # noqa: D105 self._thread.join() self._mmap = None if op.isfile(self._mmap_fname): - os.remove(self._mmap_fname) + try: + os.remove(self._mmap_fname) + # happens on Windows sometimes + except PermissionError: # pragma: no cover + pass def __del__(self): """Ensure output completes.""" From a218f96927ef06fd8a4e38363b9d98f74a93f1c3 Mon Sep 17 00:00:00 2001 From: Stefan Appelhoff Date: Fri, 20 Sep 2024 18:07:33 +0200 Subject: [PATCH 24/55] scale figsize in plot_events depending on n unique events (#12844) --- doc/changes/devel/12844.other.rst | 1 + examples/datasets/limo_data.py | 2 +- examples/preprocessing/epochs_metadata.py | 4 +- mne/viz/misc.py | 37 ++++++++++++++++--- tutorials/clinical/60_sleep.py | 8 ++-- tutorials/intro/10_overview.py | 2 +- .../preprocessing/70_fnirs_processing.py | 8 ++-- tutorials/raw/20_event_arrays.py | 4 +- 8 files changed, 47 insertions(+), 19 deletions(-) create mode 100644 doc/changes/devel/12844.other.rst diff --git a/doc/changes/devel/12844.other.rst b/doc/changes/devel/12844.other.rst new file mode 100644 index 00000000000..ce959d8132a --- /dev/null +++ b/doc/changes/devel/12844.other.rst @@ -0,0 +1 @@ +Improve automatic figure scaling of :func:`mne.viz.plot_events`, and event_id and count overview legend when a high amount of unique events is supplied, by `Stefan Appelhoff`_. diff --git a/examples/datasets/limo_data.py b/examples/datasets/limo_data.py index 54a2f34a530..f7f6d58cf19 100644 --- a/examples/datasets/limo_data.py +++ b/examples/datasets/limo_data.py @@ -107,7 +107,7 @@ print(limo_epochs.metadata.head()) # %% -# Now let's take a closer look at the information in the epochs +# Now let us take a closer look at the information in the epochs # metadata. # We want include all columns in the summary table diff --git a/examples/preprocessing/epochs_metadata.py b/examples/preprocessing/epochs_metadata.py index d1ea9a85996..9c46368afa0 100644 --- a/examples/preprocessing/epochs_metadata.py +++ b/examples/preprocessing/epochs_metadata.py @@ -35,8 +35,8 @@ # # All experimental events are stored in the :class:`~mne.io.Raw` instance as # :class:`~mne.Annotations`. We first need to convert these to events and the -# corresponding mapping from event codes to event names (``event_id``). We then -# visualize the events. +# corresponding mapping from event codes to event names (``event_id``). +# We then visualize the events. all_events, all_event_id = mne.events_from_annotations(raw) mne.viz.plot_events(events=all_events, event_id=all_event_id, sfreq=raw.info["sfreq"]) diff --git a/mne/viz/misc.py b/mne/viz/misc.py index af1345aa69d..34f0ad566bd 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -841,12 +841,18 @@ def plot_events( color = _handle_event_colors(color, unique_events, event_id) import matplotlib.pyplot as plt + unique_events_id = np.array(unique_events_id) + fig = None + figsize = plt.rcParams["figure.figsize"] + # assuming the user did not change matplotlib default params, the figsize of + # (6.4, 4.8) becomes too big if scaled beyond twice its size, so maximum 2 + _scaling = min(max(1, len(unique_events_id) / 10), 2) + figsize_scaled = np.array(figsize) * _scaling if axes is None: - fig = plt.figure(layout="constrained") + fig = plt.figure(layout="constrained", figsize=tuple(figsize_scaled)) ax = axes if axes else plt.gca() - unique_events_id = np.array(unique_events_id) min_event = np.min(unique_events_id) max_event = np.max(unique_events_id) max_x = ( @@ -861,9 +867,9 @@ def plot_events( continue y = np.full(count, idx + 1 if equal_spacing else events[ev_mask, 2][0]) if event_id is not None: - event_label = f"{event_id_rev[ev]} ({count})" + event_label = f"{event_id_rev[ev]}\n(id:{ev}; N:{count})" else: - event_label = f"N={count:d}" + event_label = f"id:{ev}; N:{count:d}" labels.append(event_label) kwargs = {} if ev in color: @@ -893,11 +899,32 @@ def plot_events( # reverse order so that the highest numbers are at the top # (match plot order) handles, labels = handles[::-1], labels[::-1] + + # spread legend entries over more columns, 25 still ~fit in one column + # (assuming non-user supplied fig) + ncols = int(np.ceil(len(unique_events_id) / 25)) + + # Make space for legend box = ax.get_position() factor = 0.8 if event_id is not None else 0.9 + factor -= 0.1 * (ncols - 1) ax.set_position([box.x0, box.y0, box.width * factor, box.height]) + + # Try some adjustments to squeeze as much information into the legend + # without cutting off the ends ax.legend( - handles, labels, loc="center left", bbox_to_anchor=(1, 0.5), fontsize="small" + handles, + labels, + loc="center left", + bbox_to_anchor=(1, 0.5), + fontsize="small", + borderpad=0, # default 0.4 + labelspacing=0.25, # default 0.5 + columnspacing=1.0, # default 2 + handletextpad=0, # default 0.8 + markerscale=2, # default 1 + borderaxespad=0.2, # default 0.5 + ncols=ncols, ) fig.canvas.draw() plt_show(show) diff --git a/tutorials/clinical/60_sleep.py b/tutorials/clinical/60_sleep.py index b25776d7435..e50b0740a7c 100644 --- a/tutorials/clinical/60_sleep.py +++ b/tutorials/clinical/60_sleep.py @@ -48,22 +48,22 @@ # Load the data # ------------- # -# Here we download the data from two subjects and the end goal is to obtain -# :term:`epochs` and its associated ground truth. +# Here we download the data of two subjects. The end goal is to obtain +# :term:`epochs` and the associated ground truth. # # MNE-Python provides us with # :func:`mne.datasets.sleep_physionet.age.fetch_data` to conveniently download # data from the Sleep Physionet dataset # :footcite:`KempEtAl2000,GoldbergerEtAl2000`. # Given a list of subjects and records, the fetcher downloads the data and -# provides us for each subject, a pair of files: +# provides us with a pair of files for each subject: # # * ``-PSG.edf`` containing the polysomnography. The :term:`raw` data from the # EEG helmet, # * ``-Hypnogram.edf`` containing the :term:`annotations` recorded by an # expert. # -# Combining these two in a :class:`mne.io.Raw` object then we can extract +# Combining these two in a :class:`mne.io.Raw` object will allow us to extract # :term:`events` based on the descriptions of the annotations to obtain the # :term:`epochs`. # diff --git a/tutorials/intro/10_overview.py b/tutorials/intro/10_overview.py index 89c1ccc3505..f61745b0024 100644 --- a/tutorials/intro/10_overview.py +++ b/tutorials/intro/10_overview.py @@ -208,7 +208,7 @@ # example of this is shown in the next section. There is also a convenient # `~mne.viz.plot_events` function for visualizing the distribution of events # across the duration of the recording (to make sure event detection worked as -# expected). Here we'll also make use of the `~mne.Info` attribute to get the +# expected). Here we will also make use of the `~mne.Info` attribute to get the # sampling frequency of the recording (so our x-axis will be in seconds instead # of in samples). diff --git a/tutorials/preprocessing/70_fnirs_processing.py b/tutorials/preprocessing/70_fnirs_processing.py index c7efa46c06c..ca1614321c3 100644 --- a/tutorials/preprocessing/70_fnirs_processing.py +++ b/tutorials/preprocessing/70_fnirs_processing.py @@ -172,7 +172,7 @@ # and the unwanted heart rate component has been removed, we can extract epochs # related to each of the experimental conditions. # -# First we extract the events of interest and visualise them to ensure they are +# First we extract the events of interest and visualize them to ensure they are # correct. events, event_dict = mne.events_from_annotations(raw_haemo) @@ -181,7 +181,7 @@ # %% # Next we define the range of our epochs, the rejection criteria, -# baseline correction, and extract the epochs. We visualise the log of which +# baseline correction, and extract the epochs. We visualize the log of which # epochs were dropped. reject_criteria = dict(hbo=80e-6) @@ -209,7 +209,7 @@ # ------------------------------------------- # # Now we can view the haemodynamic response for our tapping condition. -# We visualise the response for both the oxy- and deoxyhaemoglobin, and +# We visualize the response for both the oxy- and deoxyhaemoglobin, and # observe the expected peak in HbO at around 6 seconds consistently across # trials, and the consistent dip in HbR that is slightly delayed relative to # the HbO peak. @@ -296,7 +296,7 @@ # --------------------------------------- # # Finally we generate topo maps for the left and right conditions to view -# the location of activity. First we visualise the HbO activity. +# the location of activity. First we visualize the HbO activity. times = np.arange(4.0, 11.0, 1.0) epochs["Tapping/Left"].average(picks="hbo").plot_topomap(times=times, **topomap_args) diff --git a/tutorials/raw/20_event_arrays.py b/tutorials/raw/20_event_arrays.py index 41e8829f91d..ecc12dc7cff 100644 --- a/tutorials/raw/20_event_arrays.py +++ b/tutorials/raw/20_event_arrays.py @@ -158,8 +158,8 @@ # :func:`mne.viz.plot_events` will plot each event versus its sample number # (or, if you provide the sampling frequency, it will plot them versus time in # seconds). It can also account for the offset between sample number and sample -# index in Neuromag systems, with the ``first_samp`` parameter. If an event -# dictionary is provided, it will be used to generate a legend: +# index in Neuromag systems, with the ``first_samp`` parameter. +# If an event dictionary is provided, it will be used to generate a legend: fig = mne.viz.plot_events( events, sfreq=raw.info["sfreq"], first_samp=raw.first_samp, event_id=event_dict From b9cdca8638b20768e85822f5e5e57a5fe043d393 Mon Sep 17 00:00:00 2001 From: Stefan Appelhoff Date: Fri, 20 Sep 2024 22:04:37 +0200 Subject: [PATCH 25/55] Let users run `find_bads_muscle` also when no sensor positions are available (#12862) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Clemens Brunner Co-authored-by: Eric Larson --- doc/changes/devel/12862.other.rst | 1 + mne/preprocessing/ica.py | 51 ++++++++++++++++++++++++----- mne/preprocessing/tests/test_ica.py | 7 ++++ 3 files changed, 50 insertions(+), 9 deletions(-) create mode 100644 doc/changes/devel/12862.other.rst diff --git a/doc/changes/devel/12862.other.rst b/doc/changes/devel/12862.other.rst new file mode 100644 index 00000000000..393beeb8a8c --- /dev/null +++ b/doc/changes/devel/12862.other.rst @@ -0,0 +1 @@ +:meth:`mne.preprocessing.ICA.find_bads_muscle` can now be run when passing an ``inst`` without sensor positions. However, it will just use the first of three criteria (slope) to find muscle-related ICA components, by `Stefan Appelhoff`_. diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index b36ebfa9119..d1b0e1b0456 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -62,6 +62,7 @@ _PCA, Bunch, _check_all_same_channel_names, + _check_ch_locs, _check_compensation_grade, _check_fname, _check_on_missing, @@ -1955,6 +1956,9 @@ def find_bads_muscle( has been modified to 45 Hz as a default based on the criteria being more accurate in practice. + If ``inst`` is supplied without sensor positions, only the first criterion + (slope) is applied. + Parameters ---------- inst : instance of Raw, Epochs or Evoked @@ -1992,6 +1996,8 @@ def find_bads_muscle( """ _validate_type(threshold, "numeric", "threshold") + slope_score, focus_score, smoothness_score = None, None, None + sources = self.get_sources(inst, start=start, stop=stop) components = self.get_components() @@ -2002,11 +2008,32 @@ def find_bads_muscle( psds = psds.mean(axis=0) slopes = np.polyfit(np.log10(freqs), np.log10(psds).T, 1)[0] + # typical muscle slope is ~0.15, non-muscle components negative + # so logistic with shift -0.5 and slope 0.25 so -0.5 -> 0.5 and 0->1 + slope_score = expit((slopes + 0.5) / 0.25) + + # Need sensor positions for the criteria below, so return with only one score + # if no positions available + picks = _picks_to_idx( + inst.info, self.ch_names, "all", exclude=(), allow_empty=False + ) + if not _check_ch_locs(inst.info, picks=picks): + warn( + "No sensor positions found. Scores for bad muscle components are only " + "based on the 'slope' criterion." + ) + scores = slope_score + self.labels_["muscle"] = [ + idx for idx, score in enumerate(scores) if score > threshold + ] + return self.labels_["muscle"], scores + # compute metric #2: distance from the vertex of focus components_norm = abs(components) / np.max(abs(components), axis=0) # we need to retrieve the position from the channels that were used to # fit the ICA. N.B: picks in _find_topomap_coords includes bad channels # even if they are not provided explicitly. + pos = _find_topomap_coords( inst.info, picks=self.ch_names, sphere=sphere, ignore_overlap=True ) @@ -2016,6 +2043,10 @@ def find_bads_muscle( dists /= dists.max() focus_dists = np.dot(dists, components_norm) + # focus distance is ~65% of max electrode distance with 10% slope + # (assumes typical head size) + focus_score = expit((focus_dists - 0.65) / 0.1) + # compute metric #3: smoothness smoothnesses = np.zeros((components.shape[1],)) dists = distance.squareform(distance.pdist(pos)) @@ -2025,20 +2056,22 @@ def find_bads_muscle( comp_dists /= comp_dists.max() smoothnesses[idx] = np.multiply(dists, comp_dists).sum() - # typical muscle slope is ~0.15, non-muscle components negative - # so logistic with shift -0.5 and slope 0.25 so -0.5 -> 0.5 and 0->1 - slope_score = expit((slopes + 0.5) / 0.25) - # focus distance is ~65% of max electrode distance with 10% slope - # (assumes typical head size) - focus_score = expit((focus_dists - 0.65) / 0.1) # smoothnessness is around 150 for muscle and 450 otherwise # so use reversed logistic centered at 300 with 100 slope smoothness_score = 1 - expit((smoothnesses - 300) / 100) - # multiply so that all three components must be present - scores = slope_score * focus_score * smoothness_score + + # multiply all criteria that are present + scores = [ + score + for score in [slope_score, focus_score, smoothness_score] + if score is not None + ] + n_criteria = len(scores) + scores = np.prod(np.array(scores), axis=0) + # scale the threshold by the use of three metrics self.labels_["muscle"] = [ - idx for idx, score in enumerate(scores) if score > threshold**3 + idx for idx, score in enumerate(scores) if score > threshold**n_criteria ] return self.labels_["muscle"], scores diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index 72cfa601e56..f9dd740c5ef 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -1520,6 +1520,13 @@ def test_ica_labels(): ica.find_bads_muscle(raw) assert "muscle" in ica.labels_ + # Try without sensor locations + raw.set_montage(None) + with pytest.warns(RuntimeWarning, match="based on the 'slope' criterion"): + labels, scores = ica.find_bads_muscle(raw, threshold=0.35) + assert "muscle" in ica.labels_ + assert labels == [3] + @testing.requires_testing_data @pytest.mark.parametrize( From fc05aeb19e7e998356d400d5e93d977545a0511a Mon Sep 17 00:00:00 2001 From: Stefan Appelhoff Date: Mon, 23 Sep 2024 22:24:01 +0200 Subject: [PATCH 26/55] FIX: limit ncols to a reasonable max in plot_events (#12867) --- mne/viz/misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/viz/misc.py b/mne/viz/misc.py index 34f0ad566bd..ed2636d3961 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -901,8 +901,8 @@ def plot_events( handles, labels = handles[::-1], labels[::-1] # spread legend entries over more columns, 25 still ~fit in one column - # (assuming non-user supplied fig) - ncols = int(np.ceil(len(unique_events_id) / 25)) + # (assuming non-user supplied fig), max at 3 columns + ncols = min(int(np.ceil(len(unique_events_id) / 25)), 3) # Make space for legend box = ax.get_position() From dcb05a22cff16539e5753496668fa0a0b14b6c9a Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Mon, 23 Sep 2024 15:30:55 -0500 Subject: [PATCH 27/55] fix HTML repr table rendering (#12788) --- doc/_static/style.css | 45 ++++- .../repr/_acquisition.html.jinja | 36 ++-- mne/html_templates/repr/_channels.html.jinja | 24 +-- mne/html_templates/repr/_filters.html.jinja | 22 +- mne/html_templates/repr/_general.html.jinja | 26 +-- mne/html_templates/repr/epochs.html.jinja | 2 +- mne/html_templates/repr/evoked.html.jinja | 2 +- mne/html_templates/repr/forward.html.jinja | 6 +- mne/html_templates/repr/ica.html.jinja | 2 +- mne/html_templates/repr/info.html.jinja | 2 +- .../repr/inverse_operator.html.jinja | 2 +- mne/html_templates/repr/raw.html.jinja | 2 +- mne/html_templates/repr/spectrum.html.jinja | 2 +- .../static/_section_header_row.html.jinja | 12 ++ mne/html_templates/repr/static/repr.css | 188 +++++++++--------- mne/html_templates/repr/static/repr.js | 46 ++--- mne/html_templates/repr/tfr.html.jinja | 2 +- tutorials/intro/70_report.py | 4 +- tutorials/io/70_reading_eyetracking_data.py | 2 +- tutorials/raw/20_event_arrays.py | 4 +- 20 files changed, 209 insertions(+), 222 deletions(-) create mode 100644 mne/html_templates/repr/static/_section_header_row.html.jinja diff --git a/doc/_static/style.css b/doc/_static/style.css index 4bc6e03708e..8e53c6a49eb 100644 --- a/doc/_static/style.css +++ b/doc/_static/style.css @@ -161,20 +161,45 @@ iframe.sg_report { top: 0; } -/* TODO: Either pydata-sphinx-theme (for using Bootstrap) or sphinx-gallery (for adding table formatting) should fix this */ -.table-striped-columns>:not(caption)>tr>:nth-child(2n),.table-striped>tbody>tr:nth-of-type(odd)>* { - --bs-table-accent-bg: var(--bs-table-striped-bg); +/* ******************************************************** HTML repr tables */ + +/* make table responsive to pydata-sphinx-theme's light/dark mode */ +.table > :not(caption) > * > * { color: var(--pst-color-text-base); } -.table-hover>tbody>tr:hover>* { - --bs-table-accent-bg: var(--bs-table-hover-bg); - color: var(--pst-color-text-base); +.mne-repr-table tbody tr:hover { + background-color: var(--pst-color-table-row-hover-bg); } -.rendered_html table { - color: var(--pst-color-text-base); +.mne-repr-section-toggle > button > svg > path { + fill: var(--pst-color-text-base); } - - +/* make the expand/collapse button look nicer */ +.mne-repr-section-toggle > button { + padding: 20%; +} +/* make section header rows more distinct (and harmonize with pydata-sphinx-theme table +style in the process). Color copied from pydata-sphinx-theme; 2px copied from bootstrap. +*/ +.mne-repr-table th { + border-bottom: 2px solid var(--pst-color-primary); +} +/* harmonize the channel names buttons with the rest of the table */ +.mne-ch-names-btn { + font-size: inherit; + padding: 0.25rem; + min-width: 1.5rem; + font-weight: bold; +} +/* +.mne-ch-names-btn:hover { + background-color: var(--pst-color-); + text-decoration: underline; +} +.mne-ch-names-btn:focus-visible { + outline: 0.1875rem solid var(--pst-color-accent); + outline-offset: 0.1875rem; +} +*/ /* ***************************************************** sphinx-design fixes */ p.btn a { color: unset; diff --git a/mne/html_templates/repr/_acquisition.html.jinja b/mne/html_templates/repr/_acquisition.html.jinja index c688107b0d1..0016740cdf8 100644 --- a/mne/html_templates/repr/_acquisition.html.jinja +++ b/mne/html_templates/repr/_acquisition.html.jinja @@ -3,33 +3,23 @@ {# Collapse content during documentation build. #} {% if collapsed %} -{% set collapsed_row_class = "repr-element-faded repr-element-collapsed" %} +{% set collapsed_row_class = "mne-repr-collapsed" %} {% else %} {% set collapsed_row_class = "" %} {% endif %} - - - - - - {{ section }} - - +{%include 'static/_section_header_row.html.jinja' %} + {% if duration %} - + Duration {{ duration }} (HH:MM:SS) {% endif %} {% if inst is defined and inst | has_attr("kind") and inst | has_attr("nave") %} - + Aggregation {% if inst.kind == "average" %} average of {{ inst.nave }} epochs @@ -42,21 +32,21 @@ {% endif %} {% if inst is defined and inst | has_attr("comment") %} - + Condition {{inst.comment}} {% endif %} {% if inst is defined and inst | has_attr("events") %} - + Total number of events {{ inst.events | length }} {% endif %} {% if event_counts is defined %} - + Events counts {% if events is not none %} @@ -72,35 +62,35 @@ {% endif %} {% if inst is defined and inst | has_attr("tmin") and inst | has_attr("tmax") %} - + Time range {{ inst | format_time_range }} {% endif %} {% if inst is defined and inst | has_attr("baseline") %} - + Baseline {{ inst | format_baseline }} {% endif %} {% if info["sfreq"] is defined and info["sfreq"] is not none %} - + Sampling frequency {{ "%0.2f" | format(info["sfreq"]) }} Hz {% endif %} {% if inst is defined and inst.times is defined %} - + Time points {{ inst.times | length | format_number }} {% endif %} {% if inst is defined and inst | has_attr("metadata") %} - + Metadata {{ inst | format_metadata }} diff --git a/mne/html_templates/repr/_channels.html.jinja b/mne/html_templates/repr/_channels.html.jinja index d6ccb09312d..4f7646e9c80 100644 --- a/mne/html_templates/repr/_channels.html.jinja +++ b/mne/html_templates/repr/_channels.html.jinja @@ -3,36 +3,26 @@ {# Collapse content during documentation build. #} {% if collapsed %} -{% set collapsed_row_class = "repr-element-faded repr-element-collapsed" %} +{% set collapsed_row_class = "mne-repr-collapsed" %} {% else %} {% set collapsed_row_class = "" %} {% endif %} - - - - - - {{ section }} - - +{%include 'static/_section_header_row.html.jinja' %} + {% for channel_type, channels in (info | format_channels).items() %} {% set channel_names_good = channels["good"] | map(attribute='name_html') | join(', ') %} - + {{ channel_type }} - {% if channels["bad"] %} {% set channel_names_bad = channels["bad"] | map(attribute='name_html') | join(', ') %} - and {% endif %} @@ -41,7 +31,7 @@ {% endfor %} - + Head & sensor digitization {% if info["dig"] is not none %} {{ info["dig"] | length }} points diff --git a/mne/html_templates/repr/_filters.html.jinja b/mne/html_templates/repr/_filters.html.jinja index b01841cf137..97ede5157c1 100644 --- a/mne/html_templates/repr/_filters.html.jinja +++ b/mne/html_templates/repr/_filters.html.jinja @@ -3,40 +3,30 @@ {# Collapse content during documentation build. #} {% if collapsed %} -{% set collapsed_row_class = "repr-element-faded repr-element-collapsed" %} +{% set collapsed_row_class = "mne-repr-collapsed" %} {% else %} {% set collapsed_row_class = "" %} {% endif %} - - - - - - {{ section }} - - +{%include 'static/_section_header_row.html.jinja' %} + {% if info["highpass"] is defined and info["highpass"] is not none %} - + Highpass {{ "%0.2f" | format(info["highpass"]) }} Hz {% endif %} {% if info["lowpass"] is defined and info["lowpass"] is not none %} - + Lowpass {{ "%0.2f" | format(info["lowpass"]) }} Hz {% endif %} {% if info.projs is defined and info.projs %} - + Projections {% for p in (info | format_projs) %} diff --git a/mne/html_templates/repr/_general.html.jinja b/mne/html_templates/repr/_general.html.jinja index c9ad8310e64..a57ae40049d 100644 --- a/mne/html_templates/repr/_general.html.jinja +++ b/mne/html_templates/repr/_general.html.jinja @@ -3,26 +3,16 @@ {# Collapse content during documentation build. #} {% if collapsed %} -{% set collapsed_row_class = "repr-element-faded repr-element-collapsed" %} +{% set collapsed_row_class = "mne-repr-collapsed" %} {% else %} {% set collapsed_row_class = "" %} {% endif %} - - - - - - {{ section }} - - +{%include 'static/_section_header_row.html.jinja' %} + {% if filenames %} - + Filename(s) {% for f in filenames %} @@ -33,12 +23,12 @@ {% endif %} - + MNE object type {{ inst | data_type }} - + Measurement date {% if info["meas_date"] is defined and info["meas_date"] is not none %} {{ info["meas_date"] | dt_to_str }} @@ -47,7 +37,7 @@ {% endif %} - + Participant {% if info["subject_info"] is defined and info["subject_info"] is not none %} {% if info["subject_info"]["his_id"] is defined %} @@ -58,7 +48,7 @@ {% endif %} - + Experimenter {% if info["experimenter"] is defined and info["experimenter"] is not none %} {{ info["experimenter"] }} diff --git a/mne/html_templates/repr/epochs.html.jinja b/mne/html_templates/repr/epochs.html.jinja index 991aa8de0e3..6b33c177e87 100644 --- a/mne/html_templates/repr/epochs.html.jinja +++ b/mne/html_templates/repr/epochs.html.jinja @@ -2,7 +2,7 @@ {% set info = inst.info %} - +
{%include '_general.html.jinja' %} {%include '_acquisition.html.jinja' %} {%include '_channels.html.jinja' %} diff --git a/mne/html_templates/repr/evoked.html.jinja b/mne/html_templates/repr/evoked.html.jinja index 991aa8de0e3..6b33c177e87 100644 --- a/mne/html_templates/repr/evoked.html.jinja +++ b/mne/html_templates/repr/evoked.html.jinja @@ -2,7 +2,7 @@ {% set info = inst.info %} -
+
{%include '_general.html.jinja' %} {%include '_acquisition.html.jinja' %} {%include '_channels.html.jinja' %} diff --git a/mne/html_templates/repr/forward.html.jinja b/mne/html_templates/repr/forward.html.jinja index 22be9248ecc..510a775c2b6 100644 --- a/mne/html_templates/repr/forward.html.jinja +++ b/mne/html_templates/repr/forward.html.jinja @@ -1,12 +1,12 @@ {%include '_js_and_css.html.jinja' %} -
+
{% for channel_type, channels in (info | format_channels).items() %} {% set channel_names_good = channels["good"] | map(attribute='name_html') | join(', ') %} + + + + diff --git a/tutorials/forward/35_eeg_no_mri.py b/tutorials/forward/35_eeg_no_mri.py index a2deaa069b6..0fca916cabf 100644 --- a/tutorials/forward/35_eeg_no_mri.py +++ b/tutorials/forward/35_eeg_no_mri.py @@ -82,7 +82,7 @@ fwd = mne.make_forward_solution( raw.info, trans=trans, src=src, bem=bem, eeg=True, mindist=5.0, n_jobs=None ) -print(fwd) +fwd ############################################################################## # From here on, standard inverse imaging methods can be used! From d4c0d1c4062110db8c1af0bee47df969d8968d8b Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 7 Oct 2024 12:06:23 +0200 Subject: [PATCH 43/55] Update spacing for comments in pyproject.toml (#12886) --- pyproject.toml | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 31770fbff47..d58793a6e45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,7 @@ full-no-qt = [ "numba", "openmeeg>=2.5.5", "pandas", - "pyarrow", # only needed to avoid a deprecation warning in pandas + "pyarrow", # only needed to avoid a deprecation warning in pandas "pybv", "pyobjc-framework-Cocoa>=5.2.0; platform_system=='Darwin'", "python-picard", @@ -179,7 +179,7 @@ Homepage = "https://mne.tools/" "Source Code" = "https://github.com/mne-tools/mne-python/" [tool.bandit.assert_used] -skips = ["*/test_*.py"] # assert statements are good practice with pytest +skips = ["*/test_*.py"] # assert statements are good practice with pytest [tool.changelog-bot] @@ -211,7 +211,7 @@ exclude = [ "/mne/**/tests", "/tools", "/tutorials", -] # tracked by git, but we don't want to ship those files +] # tracked by git, but we don't want to ship those files [tool.hatch.version] raw-options = {version_scheme = "release-branch-semver"} @@ -320,37 +320,37 @@ exclude = ["__init__.py", "constants.py", "resources.py"] [tool.ruff.lint] ignore = [ - "D100", # Missing docstring in public module - "D104", # Missing docstring in public package - "D413", # Missing blank line after last section + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D413", # Missing blank line after last section ] select = ["A", "B006", "D", "E", "F", "I", "UP", "UP031", "W"] [tool.ruff.lint.per-file-ignores] "examples/*/*.py" = [ - "D205", # 1 blank line required between summary line and description - "D400", # First line should end with a period + "D205", # 1 blank line required between summary line and description + "D400", # First line should end with a period ] "examples/preprocessing/eeg_bridging.py" = [ - "E501", # line too long + "E501", # line too long ] "mne/datasets/*/*.py" = [ - "D103", # Missing docstring in public function + "D103", # Missing docstring in public function ] "mne/decoding/tests/test_*.py" = [ - "E402", # Module level import not at top of file + "E402", # Module level import not at top of file ] "mne/utils/tests/test_docs.py" = [ - "D101", # Missing docstring in public class - "D410", # Missing blank line after section - "D411", # Missing blank line before section - "D414", # Section has no content + "D101", # Missing docstring in public class + "D410", # Missing blank line after section + "D411", # Missing blank line before section + "D414", # Section has no content ] "tutorials/*/*.py" = [ - "D400", # First line should end with a period + "D400", # First line should end with a period ] "tutorials/time-freq/10_spectrum_class.py" = [ - "E501", # line too long + "E501", # line too long ] [tool.ruff.lint.pydocstyle] @@ -366,6 +366,7 @@ ignore-decorators = [ [tool.tomlsort] all = true ignore_case = true +spaces_before_inline_comment = 2 trailing_comma_inline_array = true [tool.towncrier] From 94fc435de1786bd7f5e4545970111d9bdc8aa7f9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 10:18:34 -0400 Subject: [PATCH 44/55] Bump mamba-org/setup-micromamba from 1 to 2 in the actions group (#12887) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 571d4329831..81743789785 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -99,7 +99,7 @@ jobs: sed -i '/dipy/d' environment.yml sed -i 's/- mne$/- mne-base/' environment.yml if: matrix.os == 'ubuntu-latest' && startswith(matrix.kind, 'conda') && matrix.python == '3.12' - - uses: mamba-org/setup-micromamba@v1 + - uses: mamba-org/setup-micromamba@v2 with: environment-file: ${{ env.CONDA_ENV }} environment-name: mne From bddaad9510941c1cf4e81266569172e993b59d92 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 7 Oct 2024 11:28:33 -0400 Subject: [PATCH 45/55] BUG: Fix bugs with coreg (#12884) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- doc/changes/devel/12884.bugfix.rst | 1 + mne/event.py | 2 +- mne/gui/_coreg.py | 13 ++++++++++--- mne/gui/tests/test_coreg.py | 10 +++++++++- mne/preprocessing/realign.py | 4 ++-- mne/tests/test_event.py | 6 +++++- mne/viz/_3d.py | 15 +++++++++++---- mne/viz/backends/_qt.py | 2 ++ 8 files changed, 41 insertions(+), 12 deletions(-) create mode 100644 doc/changes/devel/12884.bugfix.rst diff --git a/doc/changes/devel/12884.bugfix.rst b/doc/changes/devel/12884.bugfix.rst new file mode 100644 index 00000000000..6c5beda7241 --- /dev/null +++ b/doc/changes/devel/12884.bugfix.rst @@ -0,0 +1 @@ +Fix bug where :ref:`mne coreg` would always show MEG channels even if the "MEG Sensors" checkbox was disabled, by `Eric Larson`_. diff --git a/mne/event.py b/mne/event.py index 91b14c53e80..0c6b63f3396 100644 --- a/mne/event.py +++ b/mne/event.py @@ -519,7 +519,7 @@ def _find_events( else: logger.info( f"Trigger channel {ch_name} has a non-zero initial value of " - "{initial_value} (consider using initial_event=True to detect this " + f"{initial_value} (consider using initial_event=True to detect this " "event)" ) diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index de8e8984a7b..98e3fbfc0b3 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -1278,7 +1278,14 @@ def _add_channels(self): fnirs=self._defaults["sensor_opacity"], meg=0.25, ) - picks = pick_types(self._info, ref_meg=False, meg=True, eeg=True, fnirs=True) + picks = pick_types( + self._info, + ref_meg=False, + meg=True, + eeg=True, + fnirs=True, + exclude=(), + ) these_actors = _plot_sensors_3d( self._renderer, self._info, @@ -1295,7 +1302,7 @@ def _add_channels(self): nearest=self._nearest, **plot_types, ) - sens_actors = sum(these_actors.values(), list()) + sens_actors = sum((these_actors or {}).values(), list()) self._update_actor("sensors", sens_actors) def _add_head_surface(self): @@ -1760,7 +1767,7 @@ def _configure_dock(self): ) self._widgets["meg"] = self._renderer._dock_add_check_box( name="Show MEG sensors", - value=self._helmet, + value=self._meg_channels, callback=self._set_meg_channels, tooltip="Enable/Disable MEG sensors", layout=view_options_layout, diff --git a/mne/gui/tests/test_coreg.py b/mne/gui/tests/test_coreg.py index 409b5c6665c..9c0db7164c3 100644 --- a/mne/gui/tests/test_coreg.py +++ b/mne/gui/tests/test_coreg.py @@ -239,6 +239,9 @@ def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch, renderer_interactive_pyv assert not coreg._helmet assert coreg._actors["helmet"] is None coreg._set_helmet(True) + assert coreg._eeg_channels + coreg._set_eeg_channels(False) + assert not coreg._eeg_channels assert coreg._helmet with catch_logging() as log: coreg._redraw(verbose="debug") @@ -251,11 +254,17 @@ def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch, renderer_interactive_pyv log = log.getvalue() assert "Drawing helmet" in log assert not coreg._meg_channels + assert coreg._actors["helmet"] is not None + # TODO: Someday test our file dialogs like: + # coreg._widgets["save_trans"].widget.click() + assert len(coreg._actors["sensors"]) == 0 coreg._set_meg_channels(True) assert coreg._meg_channels with catch_logging() as log: coreg._redraw(verbose="debug") assert "Drawing meg sensors" in log.getvalue() + assert coreg._actors["helmet"] is not None + assert len(coreg._actors["sensors"]) == 306 assert coreg._orient_glyphs assert coreg._scale_by_distance assert coreg._mark_inside @@ -263,7 +272,6 @@ def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch, renderer_interactive_pyv coreg._head_opacity, float(config.get("MNE_COREG_HEAD_OPACITY", "0.8")) ) assert coreg._hpi_coils - assert coreg._eeg_channels assert coreg._head_shape_points assert coreg._scale_mode == "None" assert coreg._icp_fid_match == "matched" diff --git a/mne/preprocessing/realign.py b/mne/preprocessing/realign.py index e9101b4b952..7f08937949e 100644 --- a/mne/preprocessing/realign.py +++ b/mne/preprocessing/realign.py @@ -11,7 +11,7 @@ @verbose -def realign_raw(raw, other, t_raw, t_other, verbose=None): +def realign_raw(raw, other, t_raw, t_other, *, verbose=None): """Realign two simultaneous recordings. Due to clock drift, recordings at a given same sample rate made by two @@ -111,7 +111,7 @@ def realign_raw(raw, other, t_raw, t_other, verbose=None): ) logger.info("Resampling other") sfreq_new = raw.info["sfreq"] * first_ord - other.load_data().resample(sfreq_new, verbose=True) + other.load_data().resample(sfreq_new) with other.info._unlock(): other.info["sfreq"] = raw.info["sfreq"] diff --git a/mne/tests/test_event.py b/mne/tests/test_event.py index 7e058912537..a540f1dce93 100644 --- a/mne/tests/test_event.py +++ b/mne/tests/test_event.py @@ -38,6 +38,7 @@ shift_time_events, ) from mne.io import RawArray, read_raw_fif +from mne.utils import catch_logging base_dir = Path(__file__).parents[1] / "io" / "tests" / "data" fname = base_dir / "test-eve.fif" @@ -393,7 +394,10 @@ def test_find_events(): raw = RawArray(data, info, first_samp=7) data[0, :10] = 100 data[0, 30:40] = 200 - assert_array_equal(find_events(raw, "MYSTI"), [[37, 0, 200]]) + with catch_logging(True) as log: + assert_array_equal(find_events(raw, "MYSTI"), [[37, 0, 200]]) + log = log.getvalue() + assert "value of 100 (consider" in log assert_array_equal( find_events(raw, "MYSTI", initial_event=True), [[7, 0, 100], [37, 0, 200]] ) diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 7dbabfa1ef6..2fb94134830 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -1507,9 +1507,16 @@ def _plot_sensors_3d( elif ch_type in _MEG_CH_TYPES_SPLIT: ch_type = "meg" # only plot sensor locations if channels/original in selection - plot_sensors = (ch_type != "fnirs" or "channels" in fnirs) and ( - ch_type != "eeg" or "original" in eeg - ) + plot_sensors = True + if ch_type == "fnirs": + if not fnirs or "channels" not in fnirs: + plot_sensors = False + elif ch_type == "eeg": + if not eeg or "original" not in eeg: + plot_sensors = False + elif ch_type == "meg": + if not meg or "sensors" not in meg: + plot_sensors = False # plot sensors if isinstance(ch_coord, tuple): # is meg, plot coil ch_coord = dict(rr=ch_coord[0] * unit_scalar, tris=ch_coord[1]) @@ -1558,7 +1565,7 @@ def _plot_sensors_3d( assert isinstance(sensor_colors, dict) assert isinstance(sensor_scales, dict) for ch_type, sens_loc in locs.items(): - logger.debug(f"Drawing {ch_type} sensors") + logger.debug(f"Drawing {ch_type} sensors ({len(sens_loc)})") assert len(sens_loc) # should be guaranteed above colors = to_rgba_array(sensor_colors.get(ch_type, defaults[ch_type + "_color"])) scales = np.atleast_1d( diff --git a/mne/viz/backends/_qt.py b/mne/viz/backends/_qt.py index 17e458f98f7..11270aa7251 100644 --- a/mne/viz/backends/_qt.py +++ b/mne/viz/backends/_qt.py @@ -1213,6 +1213,8 @@ def _dock_add_file_button( ): layout = self._dock_layout if layout is None else layout weakself = weakref.ref(self) + if initial_directory is not None: + initial_directory = str(initial_directory) def callback(): self = weakself() From c0a98ed956c35ad42572253d2ce6f7ce0ed91a54 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:42:21 -0400 Subject: [PATCH 46/55] [pre-commit.ci] pre-commit autoupdate (#12888) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 06b13ffde72..951ae12d839 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.8 + rev: v0.6.9 hooks: - id: ruff name: ruff lint mne @@ -51,7 +51,7 @@ repos: # sorting - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: file-contents-sorter files: ^doc/changes/names.inc|^.mailmap From 53f258debb781aeee2c2e5cbf449f37822e3a898 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Tue, 8 Oct 2024 10:54:46 -0500 Subject: [PATCH 47/55] Website (#12885) --- doc/_static/style.css | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/doc/_static/style.css b/doc/_static/style.css index 8e53c6a49eb..70735d9c3fb 100644 --- a/doc/_static/style.css +++ b/doc/_static/style.css @@ -13,6 +13,20 @@ html[data-theme="light"] { + /* pydata-sphinx-theme overrides */ + /* ↓↓↓ use default "info" colors for "primary" */ + --pst-color-primary: #276be9; + --pst-color-primary-bg: #dce7fc; + /* ↓↓↓ use default "primary" colors for "info" */ + --pst-color-info: var(--pst-teal-500); + --pst-color-info-bg: var(--pst-teal-200); + /* ↓↓↓ use "warning" colors for "secondary" */ + --pst-color-secondary: var(--pst-color-warning); + --pst-color-secondary-bg: var(--pst-color-warning-bg); + /* ↓↓↓ make sure new primary (link) color propogates to links on code */ + --pst-color-inline-code-links: var(--pst-color-link); + /* ↓↓↓ make sure new secondary (hover) color propogates to hovering on table rows */ + --pst-color-table-row-hover-bg: var(--pst-color-secondary-bg); /* topbar logo links */ --mne-color-github: #000; --mne-color-discourse: #d0232b; @@ -21,8 +35,6 @@ html[data-theme="light"] { --copybtn-opacity: 0.75; /* card header bg color */ --mne-color-card-header: rgba(0, 0, 0, 0.05); - /* section headings */ - --mne-color-heading: #003e80; /* sphinx-gallery overrides */ --sg-download-a-background-color: var(--pst-color-primary); --sg-download-a-background-image: unset; @@ -33,6 +45,20 @@ html[data-theme="light"] { --sg-download-a-hover-box-shadow-2: none; } html[data-theme="dark"] { + /* pydata-sphinx-theme overrides */ + /* ↓↓↓ use default "info" colors for "primary" */ + --pst-color-primary: #79a3f2; + --pst-color-primary-bg: #06245d; + /* ↓↓↓ use default "primary" colors for "info" */ + --pst-color-info: var(--pst-teal-400); + --pst-color-info-bg: var(--pst-teal-800); + /* ↓↓↓ use "warning" colors for "secondary" */ + --pst-color-secondary: var(--pst-color-warning); + --pst-color-secondary-bg: var(--pst-color-warning-bg); + /* ↓↓↓ make sure new primary (link) color propogates to links on code */ + --pst-color-inline-code-links: var(--pst-color-link); + /* ↓↓↓ make sure new secondary (hover) color propogates to hovering on table rows */ + --pst-color-table-row-hover-bg: var(--pst-color-secondary-bg); /* topbar logo links */ --mne-color-github: rgb(240, 246, 252); /* from their logo SVG */ --mne-color-discourse: #FFF9AE; /* from their logo SVG */ @@ -41,8 +67,6 @@ html[data-theme="dark"] { --copybtn-opacity: 0.25; /* card header bg color */ --mne-color-card-header: rgba(255, 255, 255, 0.2); - /* section headings */ - --mne-color-heading: #b8cbe0; /* sphinx-gallery overrides */ --sg-download-a-background-color: var(--pst-color-primary); --sg-download-a-background-image: unset; @@ -52,9 +76,6 @@ html[data-theme="dark"] { --sg-download-a-hover-box-shadow-1: none; --sg-download-a-hover-box-shadow-2: none; } -h1, h2, h3, h4, h5, h6 { - color: var(--mne-color-heading); -} /* ************************************************************ Sphinx fixes */ From 06050bc72c637c297ff2e4fe0c7cba13243792e4 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 9 Oct 2024 22:54:58 +0200 Subject: [PATCH 48/55] Cast tuple of filenames to list to improve error handling (#12891) --- mne/io/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mne/io/base.py b/mne/io/base.py index b303072294a..79cbbe192ba 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -685,6 +685,8 @@ def filenames(self) -> tuple[Path | None, ...]: def filenames(self, value): """The filenames used, cast to list of paths.""" # noqa: D401 _validate_type(value, (list, tuple), "filenames") + if isinstance(value, tuple): + value = list(value) for k, elt in enumerate(value): if elt is not None: value[k] = _check_fname(elt, overwrite="read", must_exist=False) From 72faa3caa2239c6a4d973fc4cfd4a66a57ebcf00 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Thu, 10 Oct 2024 17:05:21 -0500 Subject: [PATCH 49/55] remove trailing slash from pybv base URL [ci skip] (#12892) --- doc/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index bc63c0af1c8..f66dc2af5b3 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -166,7 +166,7 @@ "mne-gui-addons": ("https://mne.tools/mne-gui-addons", None), "picard": ("https://pierreablin.github.io/picard/", None), "eeglabio": ("https://eeglabio.readthedocs.io/en/latest", None), - "pybv": ("https://pybv.readthedocs.io/en/latest/", None), + "pybv": ("https://pybv.readthedocs.io/en/latest", None), } intersphinx_mapping.update( get_intersphinx_mapping( From 922a7801a0ca6af225c7b861fe6bd97b1518af3a Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Fri, 11 Oct 2024 10:43:02 -0500 Subject: [PATCH 50/55] Sync README dependencies with pyproject.toml (#12890) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 15 +++++ README.rst | 39 +++---------- pyproject.toml | 2 +- tools/hooks/sync_dependencies.py | 95 ++++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 33 deletions(-) create mode 100755 tools/hooks/sync_dependencies.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 951ae12d839..b3f5191c438 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -63,6 +63,21 @@ repos: - id: toml-sort-fix files: pyproject.toml + # dependencies + - repo: local + hooks: + - id: dependency-sync + name: Sync dependency list between pyproject.toml and README.rst + language: python + entry: ./tools/hooks/sync_dependencies.py + files: pyproject.toml + additional_dependencies: ["mne"] + + +# these should *not* be run on CIs: +ci: + skip: [dependency-sync] # needs MNE to work, which exceeds the free tier space alloc. + # The following are too slow to run on local commits, so let's only run on CIs: # # - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/README.rst b/README.rst index 745e74849c3..50e0daaa52c 100644 --- a/README.rst +++ b/README.rst @@ -43,9 +43,6 @@ only, use pip_ in a terminal: $ pip install --upgrade mne -The current MNE-Python release requires Python 3.9 or higher. MNE-Python 0.17 -was the last release to support Python 2.7. - For more complete instructions, including our standalone installers and more advanced installation methods, please refer to the `installation guide`_. @@ -73,42 +70,20 @@ Dependencies The minimum required dependencies to run MNE-Python are: +.. ↓↓↓ BEGIN CORE DEPS LIST. DO NOT EDIT! HANDLED BY PRE-COMMIT HOOK ↓↓↓ + - `Python `__ ≥ 3.9 -- `NumPy `__ ≥ 1.24 -- `SciPy `__ ≥ 1.10 +- `NumPy `__ ≥ 1.23 +- `SciPy `__ ≥ 1.9 - `Matplotlib `__ ≥ 3.6 - `Pooch `__ ≥ 1.5 - `tqdm `__ - `Jinja2 `__ - `decorator `__ -- `lazy_loader `__ - -For full functionality, some functions require: - -- `scikit-learn `__ ≥ 1.2 -- `Joblib `__ ≥ 1.2 (for parallelization) -- `mne-qt-browser `__ ≥ 0.5 (for fast raw data visualization) -- `Qt `__ ≥ 5.15 via one of the following bindings (for fast raw data visualization and interactive 3D visualization): - - - `PySide6 `__ ≥ 6.0 - - `PyQt6 `__ ≥ 6.0 - - `PyQt5 `__ ≥ 5.15 - -- `Numba `__ ≥ 0.56.4 -- `NiBabel `__ ≥ 3.2.1 -- `OpenMEEG `__ ≥ 2.5.6 -- `pandas `__ ≥ 1.5.2 -- `Picard `__ ≥ 0.3 -- `CuPy `__ ≥ 9.0.0 (for NVIDIA CUDA acceleration) -- `DIPY `__ ≥ 1.4.0 -- `imageio `__ ≥ 2.8.0 -- `PyVista `__ ≥ 0.37 (for 3D visualization) -- `PyVistaQt `__ ≥ 0.9 (for 3D visualization) -- `mffpy `__ ≥ 0.5.7 -- `h5py `__ -- `h5io `__ -- `pymatreader `__ +- `lazy-loader `__ ≥ 0.3 +- `packaging `__ +.. ↑↑↑ END CORE DEPS LIST. DO NOT EDIT! HANDLED BY PRE-COMMIT HOOK ↑↑↑ Contributing ^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index d58793a6e45..9bdf8244b99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ scripts = {mne = "mne.commands.utils:main"} [project.optional-dependencies] # Leave this one here for backward-compat data = [] -dev = ["mne[test,doc]", "rcssmin"] +dev = ["mne[doc,test]", "rcssmin"] # Dependencies for building the documentation doc = [ "graphviz", diff --git a/tools/hooks/sync_dependencies.py b/tools/hooks/sync_dependencies.py new file mode 100755 index 00000000000..0878a5f56eb --- /dev/null +++ b/tools/hooks/sync_dependencies.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import re +from importlib.metadata import metadata +from pathlib import Path + +from mne.utils import _pl, warn + +README_PATH = Path(__file__).parents[2] / "README.rst" +BEGIN = ".. ↓↓↓ BEGIN CORE DEPS LIST. DO NOT EDIT! HANDLED BY PRE-COMMIT HOOK ↓↓↓" +END = ".. ↑↑↑ END CORE DEPS LIST. DO NOT EDIT! HANDLED BY PRE-COMMIT HOOK ↑↑↑" + +CORE_DEPS_URLS = { + "Python": "https://www.python.org", + "NumPy": "https://numpy.org", + "SciPy": "https://scipy.org", + "Matplotlib": "https://matplotlib.org", + "Pooch": "https://www.fatiando.org/pooch/latest/", + "tqdm": "https://tqdm.github.io", + "Jinja2": "https://palletsprojects.com/p/jinja/", + "decorator": "https://github.com/micheles/decorator", + "lazy-loader": "https://pypi.org/project/lazy_loader", + "packaging": "https://packaging.pypa.io/en/stable/", +} + + +def _prettify_pin(pin): + if pin is None: + return "" + pins = pin.split(",") + replacements = { + "<=": " ≤ ", + ">=": " ≥ ", + "<": " < ", + ">": " > ", + } + for old, new in replacements.items(): + pins = [p.replace(old, new) for p in pins] + pins = reversed(pins) + return ",".join(pins) + + +# get the dependency info +py_pin = metadata("mne").get("Requires-Python") +all_deps = metadata("mne").get_all("Requires-Dist") +core_deps = [f"python{py_pin}", *[dep for dep in all_deps if "extra ==" not in dep]] +pattern = re.compile(r"(?P[A-Za-z_\-\d]+)(?P[<>=]+.*)?") +core_deps_pins = { + dep["name"]: _prettify_pin(dep["pin"]) for dep in map(pattern.match, core_deps) +} +# don't show upper pin on NumPy (not important for users, just devs) +new_pin = core_deps_pins["numpy"].split(",") +new_pin.remove(" < 3") +core_deps_pins["numpy"] = new_pin[0] + +# make sure our URLs dict is minimal and complete +missing_urls = set(core_deps_pins) - {dep.lower() for dep in CORE_DEPS_URLS} +extra_urls = {dep.lower() for dep in CORE_DEPS_URLS} - set(core_deps_pins) +update_msg = ( + "please update `CORE_DEPS_URLS` mapping in `tools/hooks/sync_dependencies.py`." +) +if missing_urls: + _s = _pl(missing_urls) + raise RuntimeError( + f"Missing URL{_s} for package{_s} {', '.join(missing_urls)}; {update_msg}" + ) +if extra_urls: + _s = _pl(extra_urls) + warn(f"Superfluous URL{_s} for package{_s} {', '.join(extra_urls)}; {update_msg}") + +# construct the rST +core_deps_bullets = [ + f"- `{key} <{url}>`__{core_deps_pins[key.lower()]}" + for key, url in CORE_DEPS_URLS.items() +] +core_deps_rst = "\n" + "\n".join(core_deps_bullets) + "\n" + +# rewrite the README file +lines = README_PATH.read_text("utf-8").splitlines() +out_lines = list() +skip = False +for line in lines: + if line.strip() == BEGIN: + skip = True + out_lines.append(line) + out_lines.append(core_deps_rst) + if line.strip() == END: + skip = False + if not skip: + out_lines.append(line) +README_PATH.write_text("\n".join(out_lines) + "\n", encoding="utf-8") From f35aa5ae90616e97b0ce056a34dd98cdb7a9b184 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 18 Oct 2024 11:45:07 -0400 Subject: [PATCH 51/55] MAINT: Avoid problematic PySide6 (#12902) --- azure-pipelines.yml | 4 ++-- environment.yml | 24 ++++++++++++------------ pyproject.toml | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 3ced8ce46f9..1719dd95354 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -111,7 +111,7 @@ stages: - bash: | set -e python -m pip install --progress-bar off --upgrade pip - python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn pytest-error-for-skips python-picard qtpy nibabel sphinx-gallery PySide6 + python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn pytest-error-for-skips python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1" python -m pip uninstall -yq mne python -m pip install --progress-bar off --upgrade -e .[test] displayName: 'Install dependencies with pip' @@ -201,7 +201,7 @@ stages: displayName: 'PyQt6' - bash: | set -eo pipefail - python -m pip install PySide6 + python -m pip install "PySide6!=6.8.0,!=6.8.0.1" mne sys_info -pd mne sys_info -pd | grep "qtpy .* (PySide6=.*)$" PYTEST_QT_API=PySide6 pytest ${TEST_OPTIONS} diff --git a/environment.yml b/environment.yml index 64509188d47..9bf113c3ecf 100644 --- a/environment.yml +++ b/environment.yml @@ -2,7 +2,7 @@ name: mne channels: - conda-forge dependencies: - - python>=3.10 + - python >=3.10 - pip - numpy - scipy @@ -28,25 +28,25 @@ dependencies: - psutil - numexpr - imageio - - spyder-kernels>=1.10.0 - - imageio>=2.6.1 - - imageio-ffmpeg>=0.4.1 - - vtk>=9.2 + - spyder-kernels >=1.10.0 + - imageio >=2.6.1 + - imageio-ffmpeg >=0.4.1 + - vtk >=9.2 - traitlets - - pyvista>=0.32,!=0.35.2,!=0.38.0,!=0.38.1,!=0.38.2,!=0.38.3,!=0.38.4,!=0.38.5,!=0.38.6,!=0.42.0 - - pyvistaqt>=0.4 - - qdarkstyle!=3.2.2 + - pyvista >=0.32,!=0.35.2,!=0.38.0,!=0.38.1,!=0.38.2,!=0.38.3,!=0.38.4,!=0.38.5,!=0.38.6,!=0.42.0 + - pyvistaqt >=0.4 + - qdarkstyle !=3.2.2 - darkdetect - dipy - nibabel - - openmeeg>=2.5.5 + - openmeeg >=2.5.5 - nilearn - python-picard - qtpy - - pyside6 + - pyside6 !=6.8.0,!=6.8.0.1 - mne-base - seaborn-base - - mffpy>=0.5.7 + - mffpy >=0.5.7 - ipyevents - ipywidgets - ipympl @@ -59,7 +59,7 @@ dependencies: - mne-qt-browser - pymatreader - eeglabio - - edfio>=0.2.1 + - edfio >=0.2.1 - pybv - mamba - lazy_loader diff --git a/pyproject.toml b/pyproject.toml index 9bdf8244b99..45685a03544 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,7 +133,7 @@ full-no-qt = [ "xlrd", ] full-pyqt6 = ["mne[full]"] -full-pyside6 = ["mne[full-no-qt]", "PySide6!=6.7.0"] +full-pyside6 = ["mne[full-no-qt]", "PySide6!=6.7.0,!=6.8.0,!=6.8.0.1"] # Dependencies for MNE-Python functions that use HDF5 I/O hdf5 = ["h5io>=0.2.4", "pymatreader"] # Dependencies for running the test infrastructure From 56e522ba303fe01a25a2a0d5d7507639a83fe02c Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 18 Oct 2024 14:06:44 -0400 Subject: [PATCH 52/55] ENH: Improve report usability (#12901) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel McCloy --- doc/changes/devel/12901.bugfix.rst | 1 + doc/changes/devel/12901.newfeature.rst | 8 + mne/html_templates/report/html.html.jinja | 8 +- mne/html_templates/report/image.html.jinja | 26 +- mne/html_templates/report/section.html.jinja | 8 +- mne/html_templates/report/slider.html.jinja | 12 +- mne/report/report.py | 444 ++++++++++++++----- mne/report/tests/test_report.py | 79 +++- mne/utils/docs.py | 22 +- mne/viz/_3d.py | 60 ++- mne/viz/tests/test_3d.py | 6 +- tutorials/intro/70_report.py | 20 +- tutorials/preprocessing/59_head_positions.py | 2 +- 13 files changed, 510 insertions(+), 186 deletions(-) create mode 100644 doc/changes/devel/12901.bugfix.rst create mode 100644 doc/changes/devel/12901.newfeature.rst diff --git a/doc/changes/devel/12901.bugfix.rst b/doc/changes/devel/12901.bugfix.rst new file mode 100644 index 00000000000..d68f70f7141 --- /dev/null +++ b/doc/changes/devel/12901.bugfix.rst @@ -0,0 +1 @@ +:class:`mne.Report` HDF5 files are now written in ``mode='a'`` (append) to allow users to store other data in the HDF5 files, by `Eric Larson`_. diff --git a/doc/changes/devel/12901.newfeature.rst b/doc/changes/devel/12901.newfeature.rst new file mode 100644 index 00000000000..8d0137fce78 --- /dev/null +++ b/doc/changes/devel/12901.newfeature.rst @@ -0,0 +1,8 @@ +Improved reporting and plotting options: + +- :meth:`mne.Report.add_projs` can now plot with :func:`mne.viz.plot_projs_joint` rather than :func:`mne.viz.plot_projs_topomap` +- :class:`mne.Report` now has attributes ``img_max_width`` and ``img_max_res`` that can be used to control image scaling. +- :class:`mne.Report` now has an attribute ``collapse`` that allows collapsing sections and/or subsections by default. +- :func:`mne.viz.plot_head_positions` now has a ``totals=True`` option to show the total distance and angle of the head. + +Changes by `Eric Larson`_. diff --git a/mne/html_templates/report/html.html.jinja b/mne/html_templates/report/html.html.jinja index 62b9da07911..a9b4f881f12 100644 --- a/mne/html_templates/report/html.html.jinja +++ b/mne/html_templates/report/html.html.jinja @@ -1,7 +1,6 @@
-
- - -
+
{{ html | safe }}
diff --git a/mne/html_templates/report/image.html.jinja b/mne/html_templates/report/image.html.jinja index 06a6855ace5..41cf47e1395 100644 --- a/mne/html_templates/report/image.html.jinja +++ b/mne/html_templates/report/image.html.jinja @@ -1,17 +1,17 @@ {% extends "section.html.jinja" %} {% block html_content %} -
- {% if image_format == 'svg' %} -
- {{ img|safe }} -
- {% else %} - {{ title }} - {% endif %} +
+ {% if image_format == 'svg' %} +
+ {{ img|safe }} +
+ {% else %} + {{ title }} + {% endif %} - {% if caption is not none %} -
{{ caption }}
- {% endif %} -
+ {% if caption is not none %} +
{{ caption }}
+ {% endif %} +
{% endblock html_content %} diff --git a/mne/html_templates/report/section.html.jinja b/mne/html_templates/report/section.html.jinja index 584ff86dda9..baddf7dd8b6 100644 --- a/mne/html_templates/report/section.html.jinja +++ b/mne/html_templates/report/section.html.jinja @@ -1,7 +1,6 @@
-
- - -
+
{% block html_content %} {% for html in htmls %} diff --git a/mne/html_templates/report/slider.html.jinja b/mne/html_templates/report/slider.html.jinja index 58ee8a9f9fc..fab7f56472d 100644 --- a/mne/html_templates/report/slider.html.jinja +++ b/mne/html_templates/report/slider.html.jinja @@ -1,8 +1,7 @@
-
- -
+
- -
{{ channel_type }} -
{{ channel_type }} + + + {% if channels["bad"] %} + {% set channel_names_bad = channels["bad"] | map(attribute='name_html') | join(', ') %} + and + {% endif %} +