diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 68907b1fb87..9ff728b9285 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -76,6 +76,7 @@ Bugs API changes ~~~~~~~~~~~ +- In meth:`mne.Evoked.plot`, the default value of the ``spatial_colors`` parameter has been changed to ``'auto'``, which will use spatial colors if channel locations are available (:gh:`11201` by :newcontrib:`Hüseyin Orkun Elmas` and `Daniel McCloy`_) - Starting with this release we now follow the Python convention of using ``FutureWarning`` instead of ``DeprecationWarning`` to signal user-facing changes to our API (:gh:`11120` by `Daniel McCloy`_) - The ``bands`` parameter of :meth:`mne.Epochs.plot_psd_topomap` now accepts :class:`dict` input; legacy :class:`tuple` input is supported, but discouraged for new code (:gh:`11050` by `Daniel McCloy`_) - The :func:`mne.head_to_mri` new function parameter ``kind`` default will change from ``'ras'`` to ``'mri'`` (:gh:`11185` by `Eric Larson`_) diff --git a/doc/changes/names.inc b/doc/changes/names.inc index bd28ffd383b..ed309767b30 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -174,6 +174,8 @@ .. _Hubert Banville: https://github.com/hubertjb +.. _Hüseyin Orkun Elmas: https://github.com/HuseyinOrkun + .. _Ilias Machairas: https://github.com/JungleHippo .. _Ivana Kojcic: https://github.com/ikojcic diff --git a/mne/evoked.py b/mne/evoked.py index a369a1b909e..48c12f946b7 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -374,7 +374,7 @@ def ch_names(self): def plot(self, picks=None, exclude='bads', unit=True, show=True, ylim=None, xlim='tight', proj=False, hline=None, units=None, scalings=None, titles=None, axes=None, gfp=False, window_title=None, - spatial_colors=False, zorder='unsorted', selectable=True, + spatial_colors='auto', zorder='unsorted', selectable=True, noise_cov=None, time_unit='s', sphere=None, *, highlight=None, verbose=None): return plot_evoked( diff --git a/mne/viz/_proj.py b/mne/viz/_proj.py index ade298c40b9..ee06ec51e3b 100644 --- a/mne/viz/_proj.py +++ b/mne/viz/_proj.py @@ -157,7 +157,8 @@ def plot_projs_joint(projs, evoked, picks_trace=None, *, topomap_kwargs=None, ch_traces = evoked.data[picks_trace] ch_traces -= np.mean(ch_traces, axis=1, keepdims=True) ch_traces /= np.abs(ch_traces).max() - _plot_evoked(this_evoked, picks='all', axes=[tr_ax], **pe_kwargs) + _plot_evoked(this_evoked, picks='all', axes=[tr_ax], **pe_kwargs, + spatial_colors=False) for line in tr_ax.lines: line.set(lw=0.5, zorder=3) for t in list(tr_ax.texts): diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 310f0317cb6..6ad8f6f68f0 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -193,6 +193,18 @@ def _plot_legend(pos, colors, axis, bads, outlines, loc, size=30): _draw_outlines(ax, outlines) +def _check_spatial_colors(info, picks, spatial_colors): + """Use spatial colors if channel locations exist.""" + # NB: this assumes `picks`` has already been through _picks_to_idx() + # and it reflects *just the picks for the current subplot* + if spatial_colors == 'auto': + if len(picks) == 1: + spatial_colors = False + else: + spatial_colors = _check_ch_locs(info) + return spatial_colors + + def _plot_evoked(evoked, picks=None, exclude='bads', unit=True, show=True, ylim=None, proj=False, xlim='tight', hline=None, units=None, scalings=None, titles=None, axes=None, @@ -217,7 +229,7 @@ def _plot_evoked(evoked, picks=None, exclude='bads', unit=True, show=True, If True, draw at the end. """ import matplotlib.pyplot as plt - + _check_option('spatial_colors', spatial_colors, [True, False, 'auto']) # For evoked.plot_image ... # First input checks for group_by and axes if any of them is not None. # Either both must be dicts, or neither. @@ -251,7 +263,8 @@ def _plot_evoked(evoked, picks=None, exclude='bads', unit=True, show=True, mask_style=mask_style, mask_cmap=mask_cmap, mask_alpha=mask_alpha, time_unit=time_unit, show_names=show_names, - sphere=sphere, draw=False) + sphere=sphere, draw=False, + spatial_colors=spatial_colors) if remove_xlabels and not _is_last_row(ax): ax.set_xticklabels([]) ax.set_xlabel("") @@ -451,19 +464,22 @@ def _plot_lines(data, info, picks, fig, axes, spatial_colors, unit, units, if not gfp_only: chs = [info['chs'][i] for i in idx] locs3d = np.array([ch['loc'][:3] for ch in chs]) - if (spatial_colors is True and + # _plot_psd can pass spatial_colors=color (e.g., "black") so + # we need to use "is True" here + _spat_col = _check_spatial_colors(info, idx, spatial_colors) + if (_spat_col is True and not _check_ch_locs(info=info, picks=idx)): warn('Channel locations not available. Disabling spatial ' 'colors.') - spatial_colors = selectable = False - if spatial_colors is True and len(idx) != 1: + _spat_col = selectable = False + if _spat_col is True and len(idx) != 1: x, y, z = locs3d.T colors = _rgb(x, y, z) _handle_spatial_colors(colors, info, idx, this_type, psd, ax, sphere) else: - if isinstance(spatial_colors, (tuple, str)): - col = [spatial_colors] + if isinstance(_spat_col, (tuple, str)): + col = [_spat_col] else: col = ['k'] colors = col * len(idx) @@ -488,7 +504,7 @@ def _plot_lines(data, info, picks, fig, axes, spatial_colors, unit, units, for ch_idx, z in enumerate(z_ord): line_list.append( ax.plot(times, D[ch_idx], picker=True, - zorder=z + 1 if spatial_colors is True else 1, + zorder=z + 1 if _spat_col else 1, color=colors[ch_idx], alpha=line_alpha, linewidth=0.5)[0]) line_list[-1].set_pickradius(3.) @@ -736,11 +752,14 @@ def plot_evoked(evoked, picks=None, exclude='bads', unit=True, show=True, Plot GFP for EEG instead of RMS. Label RMS traces correctly as such. window_title : str | None The title to put at the top of the figure. - spatial_colors : bool + spatial_colors : bool | 'auto' If True, the lines are color coded by mapping physical sensor coordinates into color values. Spatially similar channels will have similar colors. Bad channels will be dotted. If False, the good - channels are plotted black and bad channels red. Defaults to False. + channels are plotted black and bad channels red. If ``'auto'``, uses + True if channel locations are present, and False if channel locations + are missing or if the data contains only a single channel. Defaults to + ``'auto'``. zorder : str | callable Which channels to put in the front or back. Only matters if ``spatial_colors`` is used. @@ -1243,7 +1262,7 @@ def whitened_gfp(x, rank=None): if not has_sss: evokeds_white[0].plot(unit=False, axes=axes_evoked, hline=[-1.96, 1.96], show=False, - time_unit=time_unit) + time_unit=time_unit, spatial_colors=False) else: for ((ch_type, picks), ax) in zip(picks_list, axes_evoked): ax.plot(times, evokeds_white[0].data[picks].T, color='k', diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 7aee10cdfe7..33d1f6cb530 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -1006,12 +1006,12 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): fig.suptitle(title) axes = axes.flatten() if isinstance(axes, np.ndarray) else axes - evoked.plot(axes=axes, show=False, time_unit='s') + evoked.plot(axes=axes, show=False, time_unit='s', spatial_colors=False) for ax in fig.axes: for line in ax.get_lines(): line.set_color('r') fig.canvas.draw() - evoked_cln.plot(axes=axes, show=False, time_unit='s') + evoked_cln.plot(axes=axes, show=False, time_unit='s', spatial_colors=False) tight_layout(fig=fig) fig.subplots_adjust(top=0.90) diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py index df995c71044..2bd46d45f40 100644 --- a/mne/viz/tests/test_evoked.py +++ b/mne/viz/tests/test_evoked.py @@ -17,6 +17,7 @@ import matplotlib.pyplot as plt from matplotlib import gridspec from matplotlib.collections import PolyCollection +from mpl_toolkits.axes_grid1.parasite_axes import HostAxes # spatial_colors import mne from mne import (read_events, Epochs, read_cov, compute_covariance, @@ -174,7 +175,8 @@ def test_plot_evoked(): [(0, 0.1), (0.1, 0.2)] ]: fig = evoked.plot(time_unit='s', highlight=highlight) - for ax in fig.get_axes(): + regular_axes = [ax for ax in fig.axes if not isinstance(ax, HostAxes)] + for ax in regular_axes: highlighted_areas = [child for child in ax.get_children() if isinstance(child, PolyCollection)] assert len(highlighted_areas) == len(np.atleast_2d(highlight)) @@ -196,14 +198,19 @@ def test_constrained_layout(): assert fig.get_constrained_layout() evoked = mne.read_evokeds(evoked_fname)[0] evoked.pick(evoked.ch_names[:2]) - evoked.plot(axes=ax) # smoke test that it does not break things + + # smoke test that it does not break things + evoked.plot(axes=ax) assert fig.get_constrained_layout() plt.close('all') def _get_amplitudes(fig): - amplitudes = [line.get_ydata() for ax in fig.axes + # ignore the spatial_colors parasite axes + regular_axes = [ax for ax in fig.axes if not isinstance(ax, HostAxes)] + amplitudes = [line.get_ydata() for ax in regular_axes for line in ax.get_lines()] + # this will exclude hlines, which are lists not arrays amplitudes = np.array( [line for line in amplitudes if isinstance(line, np.ndarray)]) return amplitudes