diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 348ee2ee0f7..247c6f89f2d 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -214,6 +214,12 @@ def score(self, X, y): """ return self.model.score(X, y) + # Needed for sklearn 1.3+ + @property + def classes_(self): + """The classes (pass-through to model).""" + return self.model.classes_ + def _set_cv(cv, estimator=None, X=None, y=None): """Set the default CV depending on whether clf is classifier/regressor.""" diff --git a/mne/io/array/tests/test_array.py b/mne/io/array/tests/test_array.py index 1a96b9e4488..dffba1da152 100644 --- a/mne/io/array/tests/test_array.py +++ b/mne/io/array/tests/test_array.py @@ -150,7 +150,13 @@ def test_array_raw(): # plotting raw2.plot() - (raw2.compute_psd(tmax=2.0, n_fft=1024).plot(average=True, spatial_colors=False)) + # TODO remove context handler after 1.4 release. + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + ( + raw2.compute_psd(tmax=2.0, n_fft=1024).plot( + average=True, spatial_colors=False + ) + ) plt.close("all") # epoching @@ -183,5 +189,7 @@ def test_array_raw(): raw = RawArray(data, info) raw.set_montage(montage) spectrum = raw.compute_psd() - spectrum.plot(average=False) # looking for nonexistent layout + # TODO remove context handler after 1.4 release. + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + spectrum.plot(average=False) # looking for nonexistent layout spectrum.plot_topo() diff --git a/mne/report/tests/test_report.py b/mne/report/tests/test_report.py index c7357b53c41..24b0d4bbe9a 100644 --- a/mne/report/tests/test_report.py +++ b/mne/report/tests/test_report.py @@ -347,12 +347,14 @@ def test_report_raw_psd_and_date(tmp_path): raw_fname_new = tmp_path / "temp_raw.fif" raw.save(raw_fname_new) report = Report(raw_psd=True) - report.parse_folder( - data_path=tmp_path, - render_bem=False, - on_error="raise", - raw_butterfly=False, - ) + # TODO: remove context handler after 1.4 release. + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + report.parse_folder( + data_path=tmp_path, + render_bem=False, + on_error="raise", + raw_butterfly=False, + ) assert isinstance(report.html, list) assert "PSD" in "".join(report.html) assert "Unknown" not in "".join(report.html) @@ -851,7 +853,8 @@ def test_manual_report_2d(tmp_path, invisible_fig): ica_ecg_scores = ica_eog_scores = np.array([3, 0]) ica_ecg_evoked = ica_eog_evoked = epochs_without_metadata.average() - r.add_raw(raw=raw, title="my raw data", tags=("raw",), psd=True, projs=False) + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + r.add_raw(raw=raw, title="my raw data", tags=("raw",), psd=True, projs=False) r.add_raw(raw=raw, title="my raw data 2", psd=False, projs=False, butterfly=1) r.add_events(events=events_fname, title="my events", sfreq=raw.info["sfreq"]) r.add_epochs( @@ -861,12 +864,15 @@ def test_manual_report_2d(tmp_path, invisible_fig): psd=False, projs=False, ) - r.add_epochs( - epochs=epochs_without_metadata, title="my epochs 2", psd=1, projs=False - ) - r.add_epochs( - epochs=epochs_without_metadata, title="my epochs 2", psd=True, projs=False - ) + # TODO: remove next two context handlers after 1.4 release. + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + r.add_epochs( + epochs=epochs_without_metadata, title="my epochs 2", psd=1, projs=False + ) + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + r.add_epochs( + epochs=epochs_without_metadata, title="my epochs 2", psd=True, projs=False + ) assert "Metadata" not in r.html[-1] # Try with metadata diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 06daec43495..9939e94a6b7 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -574,7 +574,7 @@ def plot( alpha=None, spatial_colors=True, sphere=None, - exclude="bads", + exclude=None, axes=None, show=True, ): @@ -583,6 +583,11 @@ def plot( Parameters ---------- %(picks_good_data_noref)s + + .. versionchanged:: 1.5 + In version 1.5, the default behavior will change so that all + :term:`data channels` (not just "good" data channels) are shown + by default. average : bool Whether to average across channels before plotting. If ``True``, interactive plotting of scalp topography is disabled, and @@ -615,6 +620,10 @@ def plot( %(spatial_colors_psd)s %(sphere_topomap_auto)s %(exclude_spectrum_plot)s + + .. versionchanged:: 1.5 + In version 1.5, the default behavior will change from + ``exclude='bads'`` to ``exclude=()``. %(axes_spectrum_plot_topomap)s %(show)s @@ -640,7 +649,17 @@ def plot( else: # amplitude is boolean estimate = "amplitude" if amplitude else "power" # split picks by channel type - picks = _picks_to_idx(self.info, picks, "data", with_ref_meg=False) + if picks is None or exclude is None: + warn( + "in version 1.5, the default behavior of Spectrum.plot() will " + "change so that bad channels will be shown by default. To keep the " + "old default behavior (and silence this warning), explicitly pass " + "`picks='data', exclude='bads'`." + ) + exclude = "bads" + picks = _picks_to_idx( + self.info, picks, "data", exclude=exclude, with_ref_meg=False + ) (picks_list, units_list, scalings_list, titles_list) = _split_picks_by_type( self, picks, units, scalings, titles ) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index cba3339b9ed..ed62d04a178 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1311,9 +1311,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ _exclude_spectrum = """\ -exclude : list of str | 'bads' +exclude : list of str | 'bads' | None Channel names to exclude{}. If ``'bads'``, channels - in ``spectrum.info['bads']`` are excluded; pass an empty list to + in ``spectrum.info['bads']`` are excluded; pass an empty list or tuple to plot all channels (including "bad" channels, if any). """ diff --git a/mne/utils/spectrum.py b/mne/utils/spectrum.py index d0f0d6c9c89..9857f333252 100644 --- a/mne/utils/spectrum.py +++ b/mne/utils/spectrum.py @@ -63,5 +63,5 @@ def _split_psd_kwargs(*, plot_fun=None, kwargs=None): # (otherwise integer picks could be wrong, `None` will be handled wrong # for `misc` data, etc) if plot_fun is Spectrum.plot: - plot_kwargs["picks"] = "all" # TODO: this should be the default + plot_kwargs["picks"] = "all" # TODO: this will be the default in v1.5 return kwargs, plot_kwargs diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 5dabfa7fafa..c87755e37b7 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -361,9 +361,10 @@ def test_plot_drop_log(epochs_unloaded): def test_plot_psd_epochs(epochs): """Test plotting epochs psd (+topomap).""" spectrum = epochs.compute_psd() - spectrum.plot(average=True, spatial_colors=False) - spectrum.plot(average=False, spatial_colors=True) - spectrum.plot(average=False, spatial_colors=False) + old_defaults = dict(picks="data", exclude="bads") + spectrum.plot(average=True, spatial_colors=False, **old_defaults) + spectrum.plot(average=False, spatial_colors=True, **old_defaults) + spectrum.plot(average=False, spatial_colors=False, **old_defaults) # test plot_psd_topomap errors with pytest.raises(RuntimeError, match="No frequencies in band"): spectrum.plot_topomap(bands=dict(foo=(0, 0.01))) @@ -457,13 +458,14 @@ def test_plot_psd_epochs_ctf(raw_ctf): """Test plotting CTF epochs psd (+topomap).""" evts = make_fixed_length_events(raw_ctf) epochs = Epochs(raw_ctf, evts, preload=True) + old_defaults = dict(picks="data", exclude="bads") # EEG060 is flat in this dataset with pytest.warns(UserWarning, match="for channel EEG060"): spectrum = epochs.compute_psd() for dB in [True, False]: spectrum.plot(dB=dB) spectrum.drop_channels(["EEG060"]) - spectrum.plot(spatial_colors=False, average=False) + spectrum.plot(spatial_colors=False, average=False, **old_defaults) with pytest.raises(RuntimeError, match="No frequencies in band"): spectrum.plot_topomap(bands=[(0, 0.01, "foo")]) spectrum.plot_topomap() diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index 1bc13505a7e..5533b5870c1 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -791,29 +791,33 @@ def test_plot_raw_filtered(filtorder, raw, browser_backend): def test_plot_raw_psd(raw, raw_orig): """Test plotting of raw psds.""" raw_unchanged = raw.copy() - # normal mode spectrum = raw.compute_psd() - fig = spectrum.plot(average=False) + # deprecation change handler + old_defaults = dict(picks="data", exclude="bads") + with pytest.warns(RuntimeWarning, match="bad channels will be shown"): + fig = spectrum.plot(average=False) + # normal mode + fig = spectrum.plot(average=False, **old_defaults) fig.canvas.callbacks.process( "resize_event", backend_bases.ResizeEvent("resize_event", fig.canvas) ) # specific mode picks = pick_types(spectrum.info, meg="mag", eeg=False)[:4] - spectrum.plot(picks=picks, ci="range", spatial_colors=True) - raw.compute_psd(tmax=20.0).plot(color="yellow", dB=False, alpha=0.4) + spectrum.plot(picks=picks, ci="range", spatial_colors=True, exclude="bads") + raw.compute_psd(tmax=20.0).plot(color="yellow", dB=False, alpha=0.4, **old_defaults) plt.close("all") # one axes supplied ax = plt.axes() - spectrum.plot(picks=picks, axes=ax, average=True) + spectrum.plot(picks=picks, axes=ax, average=True, exclude="bads") plt.close("all") # two axes supplied _, axs = plt.subplots(2) - spectrum.plot(axes=axs, average=True) + spectrum.plot(axes=axs, average=True, **old_defaults) plt.close("all") # need 2, got 1 ax = plt.axes() with pytest.raises(ValueError, match="of length 2.*the length is 1"): - spectrum.plot(axes=ax, average=True) + spectrum.plot(axes=ax, average=True, **old_defaults) plt.close("all") # topo psd ax = plt.subplot() @@ -859,8 +863,8 @@ def test_plot_raw_psd(raw, raw_orig): raw = raw_orig.crop(0, 1) picks = pick_types(raw.info, meg=True) spectrum = raw.compute_psd(picks=picks) - spectrum.plot(average=False) - spectrum.plot(average=True) + spectrum.plot(average=False, **old_defaults) + spectrum.plot(average=True, **old_defaults) plt.close("all") raw.set_channel_types( { @@ -871,7 +875,7 @@ def test_plot_raw_psd(raw, raw_orig): }, verbose="error", ) - fig = raw.compute_psd().plot() + fig = raw.compute_psd().plot(**old_defaults) assert len(fig.axes) == 10 plt.close("all") @@ -882,7 +886,7 @@ def test_plot_raw_psd(raw, raw_orig): raw = RawArray(data, info) picks = pick_types(raw.info, misc=True) spectrum = raw.compute_psd(picks=picks, n_fft=n_fft) - spectrum.plot(spatial_colors=False, picks=picks) + spectrum.plot(spatial_colors=False, picks=picks, exclude="bads") plt.close("all")