Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,12 @@ def score(self, X, y):
"""
return self.model.score(X, y)

# Needed for sklearn 1.3+
@property
def classes_(self):
"""The classes (pass-through to model)."""
return self.model.classes_


def _set_cv(cv, estimator=None, X=None, y=None):
"""Set the default CV depending on whether clf is classifier/regressor."""
Expand Down
12 changes: 10 additions & 2 deletions mne/io/array/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,13 @@ def test_array_raw():

# plotting
raw2.plot()
(raw2.compute_psd(tmax=2.0, n_fft=1024).plot(average=True, spatial_colors=False))
# TODO remove context handler after 1.4 release.
with pytest.warns(RuntimeWarning, match="bad channels will be shown"):
(
raw2.compute_psd(tmax=2.0, n_fft=1024).plot(
average=True, spatial_colors=False
)
)
plt.close("all")

# epoching
Expand Down Expand Up @@ -183,5 +189,7 @@ def test_array_raw():
raw = RawArray(data, info)
raw.set_montage(montage)
spectrum = raw.compute_psd()
spectrum.plot(average=False) # looking for nonexistent layout
# TODO remove context handler after 1.4 release.
with pytest.warns(RuntimeWarning, match="bad channels will be shown"):
spectrum.plot(average=False) # looking for nonexistent layout
spectrum.plot_topo()
32 changes: 19 additions & 13 deletions mne/report/tests/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,14 @@ def test_report_raw_psd_and_date(tmp_path):
raw_fname_new = tmp_path / "temp_raw.fif"
raw.save(raw_fname_new)
report = Report(raw_psd=True)
report.parse_folder(
data_path=tmp_path,
render_bem=False,
on_error="raise",
raw_butterfly=False,
)
# TODO: remove context handler after 1.4 release.
with pytest.warns(RuntimeWarning, match="bad channels will be shown"):
report.parse_folder(
data_path=tmp_path,
render_bem=False,
on_error="raise",
raw_butterfly=False,
)
assert isinstance(report.html, list)
assert "PSD" in "".join(report.html)
assert "Unknown" not in "".join(report.html)
Expand Down Expand Up @@ -851,7 +853,8 @@ def test_manual_report_2d(tmp_path, invisible_fig):
ica_ecg_scores = ica_eog_scores = np.array([3, 0])
ica_ecg_evoked = ica_eog_evoked = epochs_without_metadata.average()

r.add_raw(raw=raw, title="my raw data", tags=("raw",), psd=True, projs=False)
with pytest.warns(RuntimeWarning, match="bad channels will be shown"):
r.add_raw(raw=raw, title="my raw data", tags=("raw",), psd=True, projs=False)
r.add_raw(raw=raw, title="my raw data 2", psd=False, projs=False, butterfly=1)
r.add_events(events=events_fname, title="my events", sfreq=raw.info["sfreq"])
r.add_epochs(
Expand All @@ -861,12 +864,15 @@ def test_manual_report_2d(tmp_path, invisible_fig):
psd=False,
projs=False,
)
r.add_epochs(
epochs=epochs_without_metadata, title="my epochs 2", psd=1, projs=False
)
r.add_epochs(
epochs=epochs_without_metadata, title="my epochs 2", psd=True, projs=False
)
# TODO: remove next two context handlers after 1.4 release.
with pytest.warns(RuntimeWarning, match="bad channels will be shown"):
r.add_epochs(
epochs=epochs_without_metadata, title="my epochs 2", psd=1, projs=False
)
with pytest.warns(RuntimeWarning, match="bad channels will be shown"):
r.add_epochs(
epochs=epochs_without_metadata, title="my epochs 2", psd=True, projs=False
)
assert "Metadata" not in r.html[-1]

# Try with metadata
Expand Down
23 changes: 21 additions & 2 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def plot(
alpha=None,
spatial_colors=True,
sphere=None,
exclude="bads",
exclude=None,
axes=None,
show=True,
):
Expand All @@ -583,6 +583,11 @@ def plot(
Parameters
----------
%(picks_good_data_noref)s

.. versionchanged:: 1.5
In version 1.5, the default behavior will change so that all
:term:`data channels` (not just "good" data channels) are shown
by default.
average : bool
Whether to average across channels before plotting. If ``True``,
interactive plotting of scalp topography is disabled, and
Expand Down Expand Up @@ -615,6 +620,10 @@ def plot(
%(spatial_colors_psd)s
%(sphere_topomap_auto)s
%(exclude_spectrum_plot)s

.. versionchanged:: 1.5
In version 1.5, the default behavior will change from
``exclude='bads'`` to ``exclude=()``.
%(axes_spectrum_plot_topomap)s
%(show)s

Expand All @@ -640,7 +649,17 @@ def plot(
else: # amplitude is boolean
estimate = "amplitude" if amplitude else "power"
# split picks by channel type
picks = _picks_to_idx(self.info, picks, "data", with_ref_meg=False)
if picks is None or exclude is None:
warn(
"in version 1.5, the default behavior of Spectrum.plot() will "
"change so that bad channels will be shown by default. To keep the "
"old default behavior (and silence this warning), explicitly pass "
"`picks='data', exclude='bads'`."
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argh we should have had FutureWarning on this line. To end users it won't matter I guess so no need to backport I think.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dang.

exclude = "bads"
picks = _picks_to_idx(
self.info, picks, "data", exclude=exclude, with_ref_meg=False
)
(picks_list, units_list, scalings_list, titles_list) = _split_picks_by_type(
self, picks, units, scalings, titles
)
Expand Down
4 changes: 2 additions & 2 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,9 +1311,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
"""

_exclude_spectrum = """\
exclude : list of str | 'bads'
exclude : list of str | 'bads' | None
Channel names to exclude{}. If ``'bads'``, channels
in ``spectrum.info['bads']`` are excluded; pass an empty list to
in ``spectrum.info['bads']`` are excluded; pass an empty list or tuple to
plot all channels (including "bad" channels, if any).
"""

Expand Down
2 changes: 1 addition & 1 deletion mne/utils/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ def _split_psd_kwargs(*, plot_fun=None, kwargs=None):
# (otherwise integer picks could be wrong, `None` will be handled wrong
# for `misc` data, etc)
if plot_fun is Spectrum.plot:
plot_kwargs["picks"] = "all" # TODO: this should be the default
plot_kwargs["picks"] = "all" # TODO: this will be the default in v1.5
return kwargs, plot_kwargs
10 changes: 6 additions & 4 deletions mne/viz/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,10 @@ def test_plot_drop_log(epochs_unloaded):
def test_plot_psd_epochs(epochs):
"""Test plotting epochs psd (+topomap)."""
spectrum = epochs.compute_psd()
spectrum.plot(average=True, spatial_colors=False)
spectrum.plot(average=False, spatial_colors=True)
spectrum.plot(average=False, spatial_colors=False)
old_defaults = dict(picks="data", exclude="bads")
spectrum.plot(average=True, spatial_colors=False, **old_defaults)
spectrum.plot(average=False, spatial_colors=True, **old_defaults)
spectrum.plot(average=False, spatial_colors=False, **old_defaults)
# test plot_psd_topomap errors
with pytest.raises(RuntimeError, match="No frequencies in band"):
spectrum.plot_topomap(bands=dict(foo=(0, 0.01)))
Expand Down Expand Up @@ -457,13 +458,14 @@ def test_plot_psd_epochs_ctf(raw_ctf):
"""Test plotting CTF epochs psd (+topomap)."""
evts = make_fixed_length_events(raw_ctf)
epochs = Epochs(raw_ctf, evts, preload=True)
old_defaults = dict(picks="data", exclude="bads")
# EEG060 is flat in this dataset
with pytest.warns(UserWarning, match="for channel EEG060"):
spectrum = epochs.compute_psd()
for dB in [True, False]:
spectrum.plot(dB=dB)
spectrum.drop_channels(["EEG060"])
spectrum.plot(spatial_colors=False, average=False)
spectrum.plot(spatial_colors=False, average=False, **old_defaults)
with pytest.raises(RuntimeError, match="No frequencies in band"):
spectrum.plot_topomap(bands=[(0, 0.01, "foo")])
spectrum.plot_topomap()
Expand Down
26 changes: 15 additions & 11 deletions mne/viz/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,29 +791,33 @@ def test_plot_raw_filtered(filtorder, raw, browser_backend):
def test_plot_raw_psd(raw, raw_orig):
"""Test plotting of raw psds."""
raw_unchanged = raw.copy()
# normal mode
spectrum = raw.compute_psd()
fig = spectrum.plot(average=False)
# deprecation change handler
old_defaults = dict(picks="data", exclude="bads")
with pytest.warns(RuntimeWarning, match="bad channels will be shown"):
fig = spectrum.plot(average=False)
# normal mode
fig = spectrum.plot(average=False, **old_defaults)
fig.canvas.callbacks.process(
"resize_event", backend_bases.ResizeEvent("resize_event", fig.canvas)
)
# specific mode
picks = pick_types(spectrum.info, meg="mag", eeg=False)[:4]
spectrum.plot(picks=picks, ci="range", spatial_colors=True)
raw.compute_psd(tmax=20.0).plot(color="yellow", dB=False, alpha=0.4)
spectrum.plot(picks=picks, ci="range", spatial_colors=True, exclude="bads")
raw.compute_psd(tmax=20.0).plot(color="yellow", dB=False, alpha=0.4, **old_defaults)
plt.close("all")
# one axes supplied
ax = plt.axes()
spectrum.plot(picks=picks, axes=ax, average=True)
spectrum.plot(picks=picks, axes=ax, average=True, exclude="bads")
plt.close("all")
# two axes supplied
_, axs = plt.subplots(2)
spectrum.plot(axes=axs, average=True)
spectrum.plot(axes=axs, average=True, **old_defaults)
plt.close("all")
# need 2, got 1
ax = plt.axes()
with pytest.raises(ValueError, match="of length 2.*the length is 1"):
spectrum.plot(axes=ax, average=True)
spectrum.plot(axes=ax, average=True, **old_defaults)
plt.close("all")
# topo psd
ax = plt.subplot()
Expand Down Expand Up @@ -859,8 +863,8 @@ def test_plot_raw_psd(raw, raw_orig):
raw = raw_orig.crop(0, 1)
picks = pick_types(raw.info, meg=True)
spectrum = raw.compute_psd(picks=picks)
spectrum.plot(average=False)
spectrum.plot(average=True)
spectrum.plot(average=False, **old_defaults)
spectrum.plot(average=True, **old_defaults)
plt.close("all")
raw.set_channel_types(
{
Expand All @@ -871,7 +875,7 @@ def test_plot_raw_psd(raw, raw_orig):
},
verbose="error",
)
fig = raw.compute_psd().plot()
fig = raw.compute_psd().plot(**old_defaults)
assert len(fig.axes) == 10
plt.close("all")

Expand All @@ -882,7 +886,7 @@ def test_plot_raw_psd(raw, raw_orig):
raw = RawArray(data, info)
picks = pick_types(raw.info, misc=True)
spectrum = raw.compute_psd(picks=picks, n_fft=n_fft)
spectrum.plot(spatial_colors=False, picks=picks)
spectrum.plot(spatial_colors=False, picks=picks, exclude="bads")
plt.close("all")


Expand Down