From 53b5fa5acd77f98a829a0cec98a7dddae9111604 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Fri, 5 May 2023 09:32:57 -0500 Subject: [PATCH 1/8] arg default deprecation --- mne/time_frequency/spectrum.py | 14 +++++++++++++- mne/utils/docs.py | 4 ++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 06daec43495..16aeb31f97e 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -574,7 +574,7 @@ def plot( alpha=None, spatial_colors=True, sphere=None, - exclude="bads", + exclude=None, axes=None, show=True, ): @@ -615,6 +615,10 @@ def plot( %(spatial_colors_psd)s %(sphere_topomap_auto)s %(exclude_spectrum_plot)s + + .. versionchanged:: 1.5 + In version 1.5, the default behavior will change from + ``exclude='bads'`` to ``exclude=()``. %(axes_spectrum_plot_topomap)s %(show)s @@ -640,6 +644,14 @@ def plot( else: # amplitude is boolean estimate = "amplitude" if amplitude else "power" # split picks by channel type + if exclude is None: + warn( + "in version 1.5, the default behavior of Spectrum.plot() will " + "change so that bad channels will be shown by default. To keep the " + "old default behavior (and silence this warning), explicitly pass " + "exclude='bads'." + ) + exclude = "bads" picks = _picks_to_idx(self.info, picks, "data", with_ref_meg=False) (picks_list, units_list, scalings_list, titles_list) = _split_picks_by_type( self, picks, units, scalings, titles diff --git a/mne/utils/docs.py b/mne/utils/docs.py index cba3339b9ed..ed62d04a178 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1311,9 +1311,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ _exclude_spectrum = """\ -exclude : list of str | 'bads' +exclude : list of str | 'bads' | None Channel names to exclude{}. If ``'bads'``, channels - in ``spectrum.info['bads']`` are excluded; pass an empty list to + in ``spectrum.info['bads']`` are excluded; pass an empty list or tuple to plot all channels (including "bad" channels, if any). """ From 1d99bdfd39c7bd275f8fb3db9491b96da1cfe754 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Fri, 5 May 2023 09:33:20 -0500 Subject: [PATCH 2/8] FIX: actually *use* value of exclude --- mne/time_frequency/spectrum.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 16aeb31f97e..157b48996a6 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -652,7 +652,9 @@ def plot( "exclude='bads'." ) exclude = "bads" - picks = _picks_to_idx(self.info, picks, "data", with_ref_meg=False) + picks = _picks_to_idx( + self.info, picks, "data", exclude=exclude, with_ref_meg=False + ) (picks_list, units_list, scalings_list, titles_list) = _split_picks_by_type( self, picks, units, scalings, titles ) From 1140abe0e289534947cdda4cb7e4d5943f926b3e Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Fri, 5 May 2023 09:40:39 -0500 Subject: [PATCH 3/8] change picks too --- mne/time_frequency/spectrum.py | 9 +++++++-- mne/utils/spectrum.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 157b48996a6..9939e94a6b7 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -583,6 +583,11 @@ def plot( Parameters ---------- %(picks_good_data_noref)s + + .. versionchanged:: 1.5 + In version 1.5, the default behavior will change so that all + :term:`data channels` (not just "good" data channels) are shown + by default. average : bool Whether to average across channels before plotting. If ``True``, interactive plotting of scalp topography is disabled, and @@ -644,12 +649,12 @@ def plot( else: # amplitude is boolean estimate = "amplitude" if amplitude else "power" # split picks by channel type - if exclude is None: + if picks is None or exclude is None: warn( "in version 1.5, the default behavior of Spectrum.plot() will " "change so that bad channels will be shown by default. To keep the " "old default behavior (and silence this warning), explicitly pass " - "exclude='bads'." + "`picks='data', exclude='bads'`." ) exclude = "bads" picks = _picks_to_idx( diff --git a/mne/utils/spectrum.py b/mne/utils/spectrum.py index d0f0d6c9c89..9857f333252 100644 --- a/mne/utils/spectrum.py +++ b/mne/utils/spectrum.py @@ -63,5 +63,5 @@ def _split_psd_kwargs(*, plot_fun=None, kwargs=None): # (otherwise integer picks could be wrong, `None` will be handled wrong # for `misc` data, etc) if plot_fun is Spectrum.plot: - plot_kwargs["picks"] = "all" # TODO: this should be the default + plot_kwargs["picks"] = "all" # TODO: this will be the default in v1.5 return kwargs, plot_kwargs From 7db5d0f18133fe723d87c26a4418f4072efeecda Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Fri, 5 May 2023 10:01:41 -0500 Subject: [PATCH 4/8] fix tests --- mne/viz/tests/test_epochs.py | 10 ++++++---- mne/viz/tests/test_raw.py | 26 +++++++++++++++----------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 5dabfa7fafa..c87755e37b7 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -361,9 +361,10 @@ def test_plot_drop_log(epochs_unloaded): def test_plot_psd_epochs(epochs): """Test plotting epochs psd (+topomap).""" spectrum = epochs.compute_psd() - spectrum.plot(average=True, spatial_colors=False) - spectrum.plot(average=False, spatial_colors=True) - spectrum.plot(average=False, spatial_colors=False) + old_defaults = dict(picks="data", exclude="bads") + spectrum.plot(average=True, spatial_colors=False, **old_defaults) + spectrum.plot(average=False, spatial_colors=True, **old_defaults) + spectrum.plot(average=False, spatial_colors=False, **old_defaults) # test plot_psd_topomap errors with pytest.raises(RuntimeError, match="No frequencies in band"): spectrum.plot_topomap(bands=dict(foo=(0, 0.01))) @@ -457,13 +458,14 @@ def test_plot_psd_epochs_ctf(raw_ctf): """Test plotting CTF epochs psd (+topomap).""" evts = make_fixed_length_events(raw_ctf) epochs = Epochs(raw_ctf, evts, preload=True) + old_defaults = dict(picks="data", exclude="bads") # EEG060 is flat in this dataset with pytest.warns(UserWarning, match="for channel EEG060"): spectrum = epochs.compute_psd() for dB in [True, False]: spectrum.plot(dB=dB) spectrum.drop_channels(["EEG060"]) - spectrum.plot(spatial_colors=False, average=False) + spectrum.plot(spatial_colors=False, average=False, **old_defaults) with pytest.raises(RuntimeError, match="No frequencies in band"): spectrum.plot_topomap(bands=[(0, 0.01, "foo")]) spectrum.plot_topomap() diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index 1bc13505a7e..5533b5870c1 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -791,29 +791,33 @@ def test_plot_raw_filtered(filtorder, raw, browser_backend): def test_plot_raw_psd(raw, raw_orig): """Test plotting of raw psds.""" raw_unchanged = raw.copy() - # normal mode spectrum = raw.compute_psd() - fig = spectrum.plot(average=False) + # deprecation change handler + old_defaults = dict(picks="data", exclude="bads") + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + fig = spectrum.plot(average=False) + # normal mode + fig = spectrum.plot(average=False, **old_defaults) fig.canvas.callbacks.process( "resize_event", backend_bases.ResizeEvent("resize_event", fig.canvas) ) # specific mode picks = pick_types(spectrum.info, meg="mag", eeg=False)[:4] - spectrum.plot(picks=picks, ci="range", spatial_colors=True) - raw.compute_psd(tmax=20.0).plot(color="yellow", dB=False, alpha=0.4) + spectrum.plot(picks=picks, ci="range", spatial_colors=True, exclude="bads") + raw.compute_psd(tmax=20.0).plot(color="yellow", dB=False, alpha=0.4, **old_defaults) plt.close("all") # one axes supplied ax = plt.axes() - spectrum.plot(picks=picks, axes=ax, average=True) + spectrum.plot(picks=picks, axes=ax, average=True, exclude="bads") plt.close("all") # two axes supplied _, axs = plt.subplots(2) - spectrum.plot(axes=axs, average=True) + spectrum.plot(axes=axs, average=True, **old_defaults) plt.close("all") # need 2, got 1 ax = plt.axes() with pytest.raises(ValueError, match="of length 2.*the length is 1"): - spectrum.plot(axes=ax, average=True) + spectrum.plot(axes=ax, average=True, **old_defaults) plt.close("all") # topo psd ax = plt.subplot() @@ -859,8 +863,8 @@ def test_plot_raw_psd(raw, raw_orig): raw = raw_orig.crop(0, 1) picks = pick_types(raw.info, meg=True) spectrum = raw.compute_psd(picks=picks) - spectrum.plot(average=False) - spectrum.plot(average=True) + spectrum.plot(average=False, **old_defaults) + spectrum.plot(average=True, **old_defaults) plt.close("all") raw.set_channel_types( { @@ -871,7 +875,7 @@ def test_plot_raw_psd(raw, raw_orig): }, verbose="error", ) - fig = raw.compute_psd().plot() + fig = raw.compute_psd().plot(**old_defaults) assert len(fig.axes) == 10 plt.close("all") @@ -882,7 +886,7 @@ def test_plot_raw_psd(raw, raw_orig): raw = RawArray(data, info) picks = pick_types(raw.info, misc=True) spectrum = raw.compute_psd(picks=picks, n_fft=n_fft) - spectrum.plot(spatial_colors=False, picks=picks) + spectrum.plot(spatial_colors=False, picks=picks, exclude="bads") plt.close("all") From 8557b599f7266973bacd6ff15fa3bf0ede2c7242 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Fri, 5 May 2023 14:11:08 -0500 Subject: [PATCH 5/8] fix report and IO array tests --- mne/io/array/tests/test_array.py | 8 +++++++- mne/report/tests/test_report.py | 32 +++++++++++++++++++------------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/mne/io/array/tests/test_array.py b/mne/io/array/tests/test_array.py index 1a96b9e4488..63160ba639b 100644 --- a/mne/io/array/tests/test_array.py +++ b/mne/io/array/tests/test_array.py @@ -150,7 +150,13 @@ def test_array_raw(): # plotting raw2.plot() - (raw2.compute_psd(tmax=2.0, n_fft=1024).plot(average=True, spatial_colors=False)) + # TODO remove context handler after 1.4 release. + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + ( + raw2.compute_psd(tmax=2.0, n_fft=1024).plot( + average=True, spatial_colors=False + ) + ) plt.close("all") # epoching diff --git a/mne/report/tests/test_report.py b/mne/report/tests/test_report.py index c7357b53c41..24b0d4bbe9a 100644 --- a/mne/report/tests/test_report.py +++ b/mne/report/tests/test_report.py @@ -347,12 +347,14 @@ def test_report_raw_psd_and_date(tmp_path): raw_fname_new = tmp_path / "temp_raw.fif" raw.save(raw_fname_new) report = Report(raw_psd=True) - report.parse_folder( - data_path=tmp_path, - render_bem=False, - on_error="raise", - raw_butterfly=False, - ) + # TODO: remove context handler after 1.4 release. + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + report.parse_folder( + data_path=tmp_path, + render_bem=False, + on_error="raise", + raw_butterfly=False, + ) assert isinstance(report.html, list) assert "PSD" in "".join(report.html) assert "Unknown" not in "".join(report.html) @@ -851,7 +853,8 @@ def test_manual_report_2d(tmp_path, invisible_fig): ica_ecg_scores = ica_eog_scores = np.array([3, 0]) ica_ecg_evoked = ica_eog_evoked = epochs_without_metadata.average() - r.add_raw(raw=raw, title="my raw data", tags=("raw",), psd=True, projs=False) + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + r.add_raw(raw=raw, title="my raw data", tags=("raw",), psd=True, projs=False) r.add_raw(raw=raw, title="my raw data 2", psd=False, projs=False, butterfly=1) r.add_events(events=events_fname, title="my events", sfreq=raw.info["sfreq"]) r.add_epochs( @@ -861,12 +864,15 @@ def test_manual_report_2d(tmp_path, invisible_fig): psd=False, projs=False, ) - r.add_epochs( - epochs=epochs_without_metadata, title="my epochs 2", psd=1, projs=False - ) - r.add_epochs( - epochs=epochs_without_metadata, title="my epochs 2", psd=True, projs=False - ) + # TODO: remove next two context handlers after 1.4 release. + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + r.add_epochs( + epochs=epochs_without_metadata, title="my epochs 2", psd=1, projs=False + ) + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + r.add_epochs( + epochs=epochs_without_metadata, title="my epochs 2", psd=True, projs=False + ) assert "Metadata" not in r.html[-1] # Try with metadata From d5acbf4d0927c2116ed9933eb24696b444d4723d Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Fri, 5 May 2023 14:28:40 -0500 Subject: [PATCH 6/8] missed one --- mne/io/array/tests/test_array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mne/io/array/tests/test_array.py b/mne/io/array/tests/test_array.py index 63160ba639b..dffba1da152 100644 --- a/mne/io/array/tests/test_array.py +++ b/mne/io/array/tests/test_array.py @@ -189,5 +189,7 @@ def test_array_raw(): raw = RawArray(data, info) raw.set_montage(montage) spectrum = raw.compute_psd() - spectrum.plot(average=False) # looking for nonexistent layout + # TODO remove context handler after 1.4 release. + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + spectrum.plot(average=False) # looking for nonexistent layout spectrum.plot_topo() From d01d4a65461d3937b40b0539fa15d487098a295d Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 5 May 2023 15:31:43 -0400 Subject: [PATCH 7/8] FIX: Conditional --- mne/decoding/tests/test_base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index c7773a217d4..e372bd194f9 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -15,7 +15,7 @@ from mne import create_info, EpochsArray from mne.fixes import is_regressor, is_classifier -from mne.utils import requires_sklearn +from mne.utils import requires_sklearn, check_version from mne.decoding.base import ( _get_inverse_funcs, LinearModel, @@ -285,6 +285,10 @@ def test_get_coef_multiclass(n_features, n_targets): @requires_sklearn +@pytest.mark.xfail( + when=check_version("sklearn", "1.3"), + reason="https://github.com/scikit-learn/scikit-learn/issues/26336", +) @pytest.mark.parametrize( "n_classes, n_channels, n_times", [ From fabe1e711a69c88436c1c3cec89d5c541e8bd385 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 5 May 2023 15:43:12 -0400 Subject: [PATCH 8/8] FIX: Correct fix --- mne/decoding/base.py | 6 ++++++ mne/decoding/tests/test_base.py | 6 +----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 348ee2ee0f7..247c6f89f2d 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -214,6 +214,12 @@ def score(self, X, y): """ return self.model.score(X, y) + # Needed for sklearn 1.3+ + @property + def classes_(self): + """The classes (pass-through to model).""" + return self.model.classes_ + def _set_cv(cv, estimator=None, X=None, y=None): """Set the default CV depending on whether clf is classifier/regressor.""" diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index e372bd194f9..c7773a217d4 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -15,7 +15,7 @@ from mne import create_info, EpochsArray from mne.fixes import is_regressor, is_classifier -from mne.utils import requires_sklearn, check_version +from mne.utils import requires_sklearn from mne.decoding.base import ( _get_inverse_funcs, LinearModel, @@ -285,10 +285,6 @@ def test_get_coef_multiclass(n_features, n_targets): @requires_sklearn -@pytest.mark.xfail( - when=check_version("sklearn", "1.3"), - reason="https://github.com/scikit-learn/scikit-learn/issues/26336", -) @pytest.mark.parametrize( "n_classes, n_channels, n_times", [