diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 03da865c30d..91aadf8a311 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -69,6 +69,8 @@ Changelog - Speed up raw data reading without preload in :func:`mne.io.read_raw_nirx` by `Eric Larson`_ +- Speed up :meth:`mne.Epochs.copy` and :meth:`mne.Epochs.__getitem__` by avoiding copying immutable attributes by `Eric Larson`_ + - Support for saving movies of source time courses (STCs) with ``brain.save_movie`` method and from graphical user interface by `Guillaume Favelier`_ - Add ``mri`` and ``show_orientation`` arguments to :func:`mne.viz.plot_bem` by `Eric Larson`_ @@ -174,6 +176,8 @@ Bug - Fix bug with :class:`mne.Epochs` when metadata was not subselected properly when ``event_repeated='drop'`` by `Eric Larson`_ +- Fix bug with :class:`mne.Epochs` where ``epochs.drop_log`` was a list of list of str rather than an immutable tuple of tuple of str (not meant to be changed by the user) by `Eric Larson`_ + - Fix bug with :class:`mne.Report` where the BEM section could not be toggled by `Eric Larson`_ - Fix bug when using :meth:`mne.Epochs.crop` to exclude the baseline period would break :func:`mne.Epochs.save` / :func:`mne.read_epochs` round-trip by `Eric Larson`_ diff --git a/mne/epochs.py b/mne/epochs.py index 887a86224f4..1f2ded6a539 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -17,10 +17,8 @@ import operator import os.path as op import warnings -from distutils.version import LooseVersion import numpy as np -import scipy from .io.write import (start_file, start_block, end_file, end_block, write_int, write_float, write_float_matrix, @@ -271,6 +269,7 @@ def _handle_event_repeated(events, event_id, event_repeated, selection, # Else, we have duplicates. Triage ... _check_option('event_repeated', event_repeated, ['error', 'drop', 'merge']) + drop_log = list(drop_log) if event_repeated == 'error': raise RuntimeError('Event time samples were not unique. Consider ' 'setting the `event_repeated` parameter."') @@ -282,7 +281,7 @@ def _handle_event_repeated(events, event_id, event_repeated, selection, new_selection = selection[u_ev_idxs] drop_ev_idxs = np.setdiff1d(selection, new_selection) for idx in drop_ev_idxs: - drop_log[idx].append('DROP DUPLICATE') + drop_log[idx] = drop_log[idx] + ('DROP DUPLICATE',) selection = new_selection elif event_repeated == 'merge': logger.info('Multiple event values for single event times found. ' @@ -291,8 +290,9 @@ def _handle_event_repeated(events, event_id, event_repeated, selection, _merge_events(events, event_id, selection) drop_ev_idxs = np.setdiff1d(selection, new_selection) for idx in drop_ev_idxs: - drop_log[idx].append('MERGE DUPLICATE') + drop_log[idx] = drop_log[idx] + ('MERGE DUPLICATE',) selection = new_selection + drop_log = tuple(drop_log) # Remove obsolete kv-pairs from event_id after handling keys = new_events[:, 1:].flatten() @@ -354,9 +354,9 @@ class BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin, ShiftTimeMixin, selection : iterable | None Iterable of indices of selected epochs. If ``None``, will be automatically generated, corresponding to all non-zero events. - drop_log : list | None - List of lists of strings indicating which epochs have been marked to be - ignored. + drop_log : tuple | None + Tuple of tuple of strings indicating which epochs have been marked to + be ignored. filename : str | None The filename (if the epochs are read from disk). metadata : instance of pandas.DataFrame | None @@ -432,9 +432,10 @@ def __init__(self, info, data, events, event_id=None, tmin=-0.2, tmax=0.5, % (selected.shape, selection.shape)) self.selection = selection if drop_log is None: - self.drop_log = [list() if k in self.selection else ['IGNORED'] - for k in range(max(len(self.events), - max(self.selection) + 1))] + self.drop_log = tuple( + () if k in self.selection else ('IGNORED',) + for k in range(max(len(self.events), + max(self.selection) + 1))) else: self.drop_log = drop_log @@ -559,6 +560,9 @@ def _check_consistency(self): assert len(self.drop_log) >= len(self.events) assert hasattr(self, '_times_readonly') assert not self.times.flags['WRITEABLE'] + assert isinstance(self.drop_log, tuple) + assert all(isinstance(log, tuple) for log in self.drop_log) + assert all(isinstance(s, str) for log in self.drop_log for s in log) def load_data(self): """Load the data if not already preloaded. @@ -760,13 +764,13 @@ def _reject_setup(self, reject, flat): def _is_good_epoch(self, data, verbose=None): """Determine if epoch is good.""" if isinstance(data, str): - return False, [data] + return False, (data,) if data is None: - return False, ['NO_DATA'] + return False, ('NO_DATA',) n_times = len(self.times) if data.shape[1] < n_times: # epoch is too short ie at the end of the data - return False, ['TOO_SHORT'] + return False, ('TOO_SHORT',) if self.reject is None and self.flat is None: return True, None else: @@ -1353,6 +1357,7 @@ def _get_data(self, out=True, picks=None, item=None, verbose=None): # e.g., when calling drop_bad w/new params good_idx = [] n_out = 0 + drop_log = list(self.drop_log) assert n_events == len(self.selection) for idx, sel in enumerate(self.selection): if self.preload: # from memory @@ -1368,9 +1373,11 @@ def _get_data(self, out=True, picks=None, item=None, verbose=None): epoch = self._project_epoch(epoch_noproj) epoch_out = epoch_noproj if self._do_delayed_proj else epoch - is_good, offending_reason = self._is_good_epoch(epoch) + is_good, bad_tuple = self._is_good_epoch(epoch) if not is_good: - self.drop_log[sel] += offending_reason + assert isinstance(bad_tuple, tuple) + assert all(isinstance(x, str) for x in bad_tuple) + drop_log[sel] = drop_log[sel] + bad_tuple continue good_idx.append(idx) @@ -1383,6 +1390,8 @@ def _get_data(self, out=True, picks=None, item=None, verbose=None): dtype=epoch_out.dtype, order='C') data[n_out] = epoch_out n_out += 1 + self.drop_log = tuple(drop_log) + del drop_log self._bad_dropped = True logger.info("%d bad epochs dropped" % (n_events - len(good_idx))) @@ -1543,13 +1552,21 @@ def copy(self): epochs : instance of Epochs A copy of the object. """ - raw = self._raw - del self._raw - new = deepcopy(self) - self._raw = raw - new._raw = raw - new._set_times(new.times) # sets RO - return new + return deepcopy(self) + + def __deepcopy__(self, memodict): + """Make a deepcopy.""" + cls = self.__class__ + result = cls.__new__(cls) + for k, v in self.__dict__.items(): + # drop_log is immutable and _raw is private (and problematic to + # deepcopy) + if k in ('drop_log', '_raw', '_times_readonly'): + memodict[id(v)] = v + else: + v = deepcopy(v, memodict) + result.__dict__[k] = v + return result @verbose def save(self, fname, split_size='2GB', fmt='single', overwrite=False, @@ -1902,8 +1919,10 @@ def _drop_log_stats(drop_log, ignore=('IGNORED',)): perc : float Total percentage of epochs dropped. """ - if not isinstance(drop_log, list) or not isinstance(drop_log[0], list): - raise ValueError('drop_log must be a list of lists') + if not isinstance(drop_log, tuple) or \ + not all(isinstance(d, tuple) for d in drop_log) or \ + not all(isinstance(s, str) for d in drop_log for s in d): + raise TypeError('drop_log must be a tuple of tuple of str') perc = 100 * np.mean([len(d) > 0 for d in drop_log if not any(r in ignore for r in d)]) return perc @@ -2036,17 +2055,21 @@ class Epochs(BaseEpochs): has been dropped, this attribute would be np.array([0, 2, 3]). preload : bool Indicates whether epochs are in memory. - drop_log : list of list - A list of the same length as the event array used to initialize the + drop_log : tuple of tuple + A tuple of the same length as the event array used to initialize the Epochs object. If the i-th original event is still part of the - selection, drop_log[i] will be an empty list; otherwise it will be - a list of the reasons the event is not longer in the selection, e.g.: - - 'IGNORED' if it isn't part of the current subset defined by the user; - 'NO_DATA' or 'TOO_SHORT' if epoch didn't contain enough data; - names of channels that exceeded the amplitude threshold; - 'EQUALIZED_COUNTS' (see equalize_event_counts); - or 'USER' for user-defined reasons (see drop method). + selection, drop_log[i] will be an empty tuple; otherwise it will be + a tuple of the reasons the event is not longer in the selection, e.g.: + + - 'IGNORED' + If it isn't part of the current subset defined by the user + - 'NO_DATA' or 'TOO_SHORT' + If epoch didn't contain enough data names of channels that exceeded + the amplitude threshold + - 'EQUALIZED_COUNTS' + See :meth:`~mne.Epochs.equalize_event_counts` + - 'USER' + For user-defined reasons (see :meth:`~mne.Epochs.drop`). filename : str The filename of the object. times : ndarray @@ -2380,13 +2403,6 @@ def _get_drop_indices(event_times, method): return indices -def _fix_fill(fill): - """Fix bug on old scipy.""" - if LooseVersion(scipy.__version__) < LooseVersion('0.12'): - fill = fill[:, np.newaxis] - return fill - - def _minimize_time_diff(t_shorter, t_longer): """Find a boolean mask to minimize timing differences.""" from scipy.interpolate import interp1d @@ -2413,7 +2429,7 @@ def _minimize_time_diff(t_shorter, t_longer): x2 = np.arange(len(t_longer) - ii - 1) t_keeps = np.array([t_longer[km] for km in keep_mask]) longer_interp = interp1d(x2, t_keeps, axis=1, - fill_value=_fix_fill(t_keeps[:, -1]), + fill_value=t_keeps[:, -1], **kwargs) d1 = longer_interp(x1) - t_shorter d2 = shorter_interp(x2) - t_keeps @@ -2430,7 +2446,7 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False, If full_report=True, it will give True/False as well as a list of all offending channels. """ - bad_list = list() + bad_tuple = tuple() has_printed = False checkable = np.ones(len(ch_names), dtype=bool) checkable[np.array([c in ignore_chs @@ -2448,23 +2464,23 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False, checkable_idx))[0] if len(idx_deltas) > 0: - ch_name = [ch_names[idx[i]] for i in idx_deltas] + bad_names = [ch_names[idx[i]] for i in idx_deltas] if (not has_printed): logger.info(' Rejecting %s epoch based on %s : ' - '%s' % (t, name, ch_name)) + '%s' % (t, name, bad_names)) has_printed = True if not full_report: return False else: - bad_list.extend(ch_name) + bad_tuple += tuple(bad_names) if not full_report: return True else: - if bad_list == []: + if bad_tuple == (): return True, None else: - return False, bad_list + return False, bad_tuple def _read_one_epoch_file(f, tree, preload): @@ -2541,7 +2557,7 @@ def _read_one_epoch_file(f, tree, preload): selection = np.array(tag.data) elif kind == FIFF.FIFF_MNE_EPOCHS_DROP_LOG: tag = read_tag(fid, pos) - drop_log = json.loads(tag.data) + drop_log = tuple(tuple(x) for x in json.loads(tag.data)) elif kind == FIFF.FIFF_MNE_EPOCHS_REJECT_FLAT: tag = read_tag(fid, pos) reject_params = json.loads(tag.data) @@ -2604,7 +2620,7 @@ def _read_one_epoch_file(f, tree, preload): if selection is None: selection = np.arange(len(events)) if drop_log is None: - drop_log = [[] for _ in range(len(events))] + drop_log = ((),) * len(events) return (info, data, data_tag, events, event_id, metadata, tmin, tmax, baseline, selection, drop_log, epoch_shape, cals, reject_params, @@ -2744,12 +2760,13 @@ def __init__(self, fname, proj=True, preload=True, assert len(drop_log) % len(fnames) == 0 step = len(drop_log) // len(fnames) offsets = np.arange(step, len(drop_log) + 1, step) + drop_log = list(drop_log) for i1, i2 in zip(offsets[:-1], offsets[1:]): other_log = drop_log[i1:i2] for k, (a, b) in enumerate(zip(drop_log, other_log)): - if a == ['IGNORED'] and b != ['IGNORED']: + if a == ('IGNORED',) and b != ('IGNORED',): drop_log[k] = b - drop_log = drop_log[:step] + drop_log = tuple(drop_log[:step]) # call BaseEpochs constructor super(EpochsFIF, self).__init__( @@ -2949,7 +2966,7 @@ def _concatenate_epochs(epochs_list, with_data=True, add_offset=True): baseline, tmin, tmax = out.baseline, out.tmin, out.tmax info = deepcopy(out.info) verbose = out.verbose - drop_log = deepcopy(out.drop_log) + drop_log = out.drop_log event_id = deepcopy(out.event_id) selection = out.selection # offset is the last epoch + tmax + 10 second @@ -2985,7 +3002,7 @@ def _concatenate_epochs(epochs_list, with_data=True, add_offset=True): int((10 + tmax) * epochs.info['sfreq'])) events.append(evs) selection = np.concatenate((selection, epochs.selection)) - drop_log.extend(epochs.drop_log) + drop_log = drop_log + epochs.drop_log event_id.update(epochs.event_id) metadata.append(epochs.metadata) events = np.concatenate(events, axis=0) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index f1e5163907e..bc0540b4aa7 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -71,10 +71,10 @@ def test_event_repeated(): events = np.array([[10, 0, 1], [10, 0, 2]]) epochs = mne.Epochs(raw, events, event_repeated='drop') - assert epochs.drop_log == [[], ['DROP DUPLICATE']] + assert epochs.drop_log == ((), ('DROP DUPLICATE',)) assert_array_equal(epochs.selection, [0]) epochs = mne.Epochs(raw, events, event_repeated='merge') - assert epochs.drop_log == [[], ['MERGE DUPLICATE']] + assert epochs.drop_log == ((), ('MERGE DUPLICATE',)) assert_array_equal(epochs.selection, [0]) @@ -87,34 +87,34 @@ def test_handle_event_repeated(): [5, 0, 2], [5, 0, 1], [5, 0, 3], [7, 0, 1]]) SELECTION = np.arange(len(EVENTS)) - DROP_LOG = [list() for _ in range(len(EVENTS))] + DROP_LOG = ((),) * len(EVENTS) with pytest.raises(RuntimeError, match='Event time samples were not uniq'): _handle_event_repeated(EVENTS, EVENT_ID, event_repeated='error', selection=SELECTION, - drop_log=deepcopy(DROP_LOG)) + drop_log=DROP_LOG) events, event_id, selection, drop_log = _handle_event_repeated( - EVENTS, EVENT_ID, 'drop', SELECTION, deepcopy(DROP_LOG)) + EVENTS, EVENT_ID, 'drop', SELECTION, DROP_LOG) assert_array_equal(events, [[0, 0, 1], [3, 0, 2], [5, 0, 2], [7, 0, 1]]) assert_array_equal(events, EVENTS[selection]) unselection = np.setdiff1d(SELECTION, selection) - assert all(drop_log[k] == ['DROP DUPLICATE'] for k in unselection) + assert all(drop_log[k] == ('DROP DUPLICATE',) for k in unselection) assert event_id == {'aud': 1, 'vis': 2} events, event_id, selection, drop_log = _handle_event_repeated( - EVENTS, EVENT_ID, 'merge', SELECTION, deepcopy(DROP_LOG)) + EVENTS, EVENT_ID, 'merge', SELECTION, DROP_LOG) assert_array_equal(events[0][-1], events[1][-1]) assert_array_equal(events, [[0, 0, 4], [3, 0, 4], [5, 0, 5], [7, 0, 1]]) assert_array_equal(events[:, :2], EVENTS[selection][:, :2]) unselection = np.setdiff1d(SELECTION, selection) - assert all(drop_log[k] == ['MERGE DUPLICATE'] for k in unselection) + assert all(drop_log[k] == ('MERGE DUPLICATE',) for k in unselection) assert set(event_id.keys()) == set(['aud', 'aud/vis', 'aud/foo/vis']) assert event_id['aud/vis'] == 4 # Test early return with no changes: no error for wrong event_repeated arg fine_events = np.array([[0, 0, 1], [1, 0, 2]]) events, event_id, selection, drop_log = _handle_event_repeated( - fine_events, EVENT_ID, 'no', [0, 2], deepcopy(DROP_LOG)) + fine_events, EVENT_ID, 'no', [0, 2], DROP_LOG) assert event_id == EVENT_ID assert_array_equal(selection, [0, 2]) assert drop_log == DROP_LOG @@ -131,7 +131,7 @@ def test_handle_event_repeated(): assert set(event_id.keys()) == set(['aud/vis']) assert event_id['aud/vis'] == 5 assert_array_equal(selection, [0]) - assert drop_log[1] == ['MERGE DUPLICATE'] + assert drop_log[1] == ('MERGE DUPLICATE',) assert_array_equal(events, np.array([[0, 0, 5], ])) del heterogeneous_events @@ -144,7 +144,7 @@ def test_handle_event_repeated(): assert set(event_id.keys()) == set(['aud', 'vis', 'aud/vis']) assert_array_equal(events, np.array([[0, 99, 4], [1, 0, 1], [2, 0, 2]])) assert_array_equal(selection, [1, 4, 7]) - assert drop_log[3] == ['MERGE DUPLICATE'] + assert drop_log[3] == ('MERGE DUPLICATE',) del homogeneous_events # Test dropping instead of merging, if event_codes to be merged are equal @@ -153,7 +153,7 @@ def test_handle_event_repeated(): equal_events, EVENT_ID, 'merge', [3, 5], deepcopy(DROP_LOG)) assert_array_equal(events, np.array([[0, 0, 1], ])) assert_array_equal(selection, [3]) - assert drop_log[5] == ['MERGE DUPLICATE'] + assert drop_log[5] == ('MERGE DUPLICATE',) assert set(event_id.keys()) == set(['aud']) @@ -287,13 +287,23 @@ def test_average_movements(): head_pos=head_pos) # prj +def _assert_drop_log_types(drop_log): + __tracebackhide__ = True + assert isinstance(drop_log, tuple), 'drop_log should be tuple' + assert all(isinstance(log, tuple) for log in drop_log), \ + 'drop_log[ii] should be tuple' + assert all(isinstance(s, str) for log in drop_log for s in log), \ + 'drop_log[ii][jj] should be str' + + def test_reject(): """Test epochs rejection.""" raw, events, picks = _get_data() # cull the list just to contain the relevant event events = events[events[:, 2] == event_id, :] selection = np.arange(3) - drop_log = [[]] * 3 + [['MEG 2443']] * 4 + drop_log = ((),) * 3 + (('MEG 2443',),) * 4 + _assert_drop_log_types(drop_log) pytest.raises(TypeError, pick_types, raw) picks_meg = pick_types(raw.info, meg=True, eeg=False) pytest.raises(TypeError, Epochs, raw, events, event_id, tmin, tmax, @@ -318,11 +328,12 @@ def test_reject(): # no rejection epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, preload=preload) + _assert_drop_log_types(epochs.drop_log) pytest.raises(ValueError, epochs.drop_bad, reject='foo') epochs.drop_bad() assert_equal(len(epochs), len(events)) assert_array_equal(epochs.selection, np.arange(len(events))) - assert epochs.drop_log == [[]] * 7 + assert epochs.drop_log == ((),) * 7 if proj not in data_7: data_7[proj] = epochs.get_data() assert_array_equal(epochs.get_data(), data_7[proj]) @@ -330,7 +341,9 @@ def test_reject(): # with rejection epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, reject=reject, preload=preload) + _assert_drop_log_types(epochs.drop_log) epochs.drop_bad() + _assert_drop_log_types(epochs.drop_log) assert_equal(len(epochs), len(events) - 4) assert_array_equal(epochs.selection, selection) assert epochs.drop_log == drop_log @@ -1260,10 +1273,10 @@ def test_reject_epochs(): # Should match # mne_process_raw --raw test_raw.fif --projoff \ # --saveavetag -ave --ave test.ave --filteroff - assert (n_events > n_clean_epochs) - assert (n_clean_epochs == 3) - assert (epochs.drop_log == [[], [], [], ['MEG 2443'], ['MEG 2443'], - ['MEG 2443'], ['MEG 2443']]) + assert n_events > n_clean_epochs + assert n_clean_epochs == 3 + assert epochs.drop_log == ((), (), (), ('MEG 2443',), ('MEG 2443',), + ('MEG 2443',), ('MEG 2443',)) # Ensure epochs are not dropped based on a bad channel raw_2 = raw.copy() @@ -1288,18 +1301,17 @@ def test_reject_epochs(): reject=reject, flat=flat, reject_tmin=0., reject_tmax=.1) data = epochs.get_data() n_clean_epochs = len(data) - assert (n_clean_epochs == 7) - assert (len(epochs) == 7) - assert (epochs.times[epochs._reject_time][0] >= 0.) - assert (epochs.times[epochs._reject_time][-1] <= 0.1) + assert n_clean_epochs == 7 + assert len(epochs) == 7 + assert epochs.times[epochs._reject_time][0] >= 0. + assert epochs.times[epochs._reject_time][-1] <= 0.1 # Invalid data for _is_good_epoch function epochs = Epochs(raw, events1, event_id, tmin, tmax) - assert_equal(epochs._is_good_epoch(None), (False, ['NO_DATA'])) - assert_equal(epochs._is_good_epoch(np.zeros((1, 1))), - (False, ['TOO_SHORT'])) + assert epochs._is_good_epoch(None) == (False, ('NO_DATA',)) + assert epochs._is_good_epoch(np.zeros((1, 1))) == (False, ('TOO_SHORT',)) data = epochs[0].get_data()[0] - assert_equal(epochs._is_good_epoch(data), (True, None)) + assert epochs._is_good_epoch(data) == (True, None) def test_preload_epochs(): @@ -1640,15 +1652,14 @@ def test_epoch_eq(): """Test epoch count equalization and condition combining.""" raw, events, picks = _get_data() # equalizing epochs objects - epochs_1 = Epochs(raw, events, event_id, tmin, tmax, picks=picks) - epochs_2 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks) + events_1 = events[events[:, 2] == event_id] + epochs_1 = Epochs(raw, events_1, event_id, tmin, tmax, picks=picks) + events_2 = events[events[:, 2] == event_id_2] + epochs_2 = Epochs(raw, events_2, event_id_2, tmin, tmax, picks=picks) epochs_1.drop_bad() # make sure drops are logged assert_equal(len([log for log in epochs_1.drop_log if not log]), len(epochs_1.events)) - drop_log1 = epochs_1.drop_log = [[] for _ in range(len(epochs_1.events))] - drop_log2 = [[] if log == ['EQUALIZED_COUNT'] else log for log in - epochs_1.drop_log] - assert_equal(drop_log1, drop_log2) + assert epochs_1.drop_log == ((),) * len(epochs_1.events) assert_equal(len([lg for lg in epochs_1.drop_log if not lg]), len(epochs_1.events)) assert (epochs_1.events.shape[0] != epochs_2.events.shape[0]) @@ -1670,8 +1681,8 @@ def test_epoch_eq(): old_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']] epochs.equalize_event_counts(['a', 'b']) # undo the eq logging - drop_log2 = [[] if log == ['EQUALIZED_COUNT'] else log for log in - epochs.drop_log] + drop_log2 = tuple(() if log == ('EQUALIZED_COUNT',) else log + for log in epochs.drop_log) assert_equal(drop_log1, drop_log2) assert_equal(len([log for log in epochs.drop_log if not log]), @@ -2024,7 +2035,7 @@ def test_drop_epochs(): assert_array_equal(epochs.selection, np.where(events[:, 2] == event_id)[0]) assert_equal(len(epochs.drop_log), len(events)) - assert (all(epochs.drop_log[k] == ['IGNORED'] + assert (all(epochs.drop_log[k] == ('IGNORED',) for k in set(range(len(events))) - set(epochs.selection))) selection = epochs.selection.copy() @@ -2054,10 +2065,10 @@ def test_drop_epochs_mult(): # In the preload case you cannot know the bads if already ignored assert_equal(len(epochs1.drop_log), len(epochs2.drop_log)) for d1, d2 in zip(epochs1.drop_log, epochs2.drop_log): - if d1 == ['IGNORED']: - assert (d2 == ['IGNORED']) - if d1 != ['IGNORED'] and d1 != []: - assert ((d2 == d1) or (d2 == ['IGNORED'])) + if d1 == ('IGNORED',): + assert (d2 == ('IGNORED',)) + if d1 != ('IGNORED',) and d1 != []: + assert ((d2 == d1) or (d2 == ('IGNORED',))) if d1 == []: assert (d2 == []) assert_array_equal(epochs1.events, epochs2.events) diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index 5d787fce4fe..05e2cf8eb33 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -182,10 +182,13 @@ def _getitem(self, item, reason='IGNORED', copy=True, drop_event_id=True, has_selection = hasattr(inst, 'selection') if has_selection: key_selection = inst.selection[select] + drop_log = list(inst.drop_log) if reason is not None: for k in np.setdiff1d(inst.selection, key_selection): - inst.drop_log[k] = [reason] + drop_log[k] = (reason,) + inst.drop_log = tuple(drop_log) inst.selection = key_selection + del drop_log inst.events = np.atleast_2d(inst.events[select]) if inst.metadata is not None: diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 3c02346bad2..5653188c36b 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -269,9 +269,12 @@ def test_plot_drop_log(): pytest.raises(ValueError, epochs.plot_drop_log) epochs.drop_bad() epochs.plot_drop_log() - plot_drop_log([['One'], [], []]) - plot_drop_log([['One'], ['Two'], []]) - plot_drop_log([['One'], ['One', 'Two'], []]) + plot_drop_log((('One',), (), ())) + plot_drop_log((('One',), ('Two',), ())) + plot_drop_log((('One',), ('One', 'Two'), ())) + for arg in ([], ([],), (1,)): + with pytest.raises(TypeError, match='tuple of tuple of str'): + plot_drop_log(arg) plt.close('all')