diff --git a/mne/io/pick.py b/mne/io/pick.py index 693586ba5f0..b74ac242161 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -886,9 +886,10 @@ def _pick_data_or_ica(info, exclude=()): def _picks_to_idx(info, picks, none='data', exclude='bads', allow_empty=False, - with_ref_meg=True): + with_ref_meg=True, return_kind=False): """Convert and check pick validity.""" from .meas_info import Info + picked_ch_type_or_generic = False # # None -> all, data, or data_or_ica (ndarray of int) # @@ -920,7 +921,11 @@ def _picks_to_idx(info, picks, none='data', exclude='bads', allow_empty=False, if picks.ndim != 1: raise ValueError('picks must be 1D, got %sD' % (picks.ndim,)) if picks.dtype.char in ('S', 'U'): - picks = _picks_str_to_idx(info, picks, exclude, with_ref_meg) + picks = _picks_str_to_idx(info, picks, exclude, with_ref_meg, + return_kind) + if return_kind: + picked_ch_type_or_generic = picks[1] + picks = picks[0] if picks.dtype.kind not in ['i', 'u']: raise TypeError('picks must be a list of int or list of str, got ' 'a data type of %s' % (picks.dtype,)) @@ -939,10 +944,12 @@ def _picks_to_idx(info, picks, none='data', exclude='bads', allow_empty=False, raise ValueError('All picks must be < n_channels (%d), got %r' % (n_chan, orig_picks)) picks %= n_chan # ensure positive + if return_kind: + return picks, picked_ch_type_or_generic return picks -def _picks_str_to_idx(info, picks, exclude, with_ref_meg): +def _picks_str_to_idx(info, picks, exclude, with_ref_meg, return_kind=False): """Turn a list of str into ndarray of int.""" # special case for _picks_to_idx w/no info: shouldn't really happen if isinstance(info, int): @@ -1027,6 +1034,9 @@ def _picks_str_to_idx(info, picks, exclude, with_ref_meg): 'picks for these') else: picks = np.array(all_picks[np.where(any_found)[0][0]]) + if return_kind: + picked_ch_type_or_generic = not len(picks_name) + return picks, picked_ch_type_or_generic return picks diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 8bf9cfb7cac..4a55964dc80 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -18,7 +18,6 @@ from ..utils import (verbose, get_config, set_config, logger, warn, _pl, fill_doc) -from ..utils.check import _is_numeric from ..io.pick import (pick_types, channel_type, _get_channel_types, _picks_to_idx, _DATA_CH_TYPES_SPLIT, _DATA_CH_TYPES_ORDER_DEFAULT) @@ -169,14 +168,12 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None, if combine is not None: ts_args["show_sensors"] = False - if picks is None or not _is_numeric(picks): - picks = _picks_to_idx(epochs.info, picks) - if group_by is None: - logger.info("No picks and no groupby, showing the first five " - "channels ...") - picks = picks[:5] # take 5 picks to prevent spawning many figs - else: - picks = _picks_to_idx(epochs.info, picks) + picks, picked_ch_type_or_generic = _picks_to_idx(epochs.info, picks, + return_kind=True) + if picked_ch_type_or_generic and group_by is None: + logger.info("No picks and no groupby, showing the first five " + "channels ...") + picks = picks[:5] # take 5 picks to prevent spawning too many figs if "invert_y" in ts_args: raise NotImplementedError("'invert_y' found in 'ts_args'. " diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 22cbbd3fdb4..133731a508f 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -137,6 +137,7 @@ def test_plot_epochs_image(): """Test plotting of epochs image.""" epochs = _get_epochs() epochs.plot_image(picks=[1, 2]) + epochs.plot_image(picks='mag') overlay_times = [0.1] epochs.plot_image(picks=[1], order=[0], overlay_times=overlay_times, vmin=0.01, title="test"