diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index b0d44a725c3..65525514cf1 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -141,6 +141,8 @@ Changelog - :class:`mne.Report` now can add topomaps of SSP projectors to the generated report. This behavior can be toggled via the new ``projs`` argument by `Richard Höchenberger`_ +- Add function :func:`mne.channels.combine_channels` to combine channels from Raw, Epochs, or Evoked according to ROIs (combinations including mean, median, or standard deviation; can also use a callable) by `Johann Benerradi`_ + Bug ~~~ diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 3e9471b7b5b..7f027dec383 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -309,3 +309,5 @@ .. _Ezequiel Mikulan: https://github.com/ezemikulan .. _Jan Sedivy: https://github.com/honzaseda + +.. _Johann Benerradi: https://github.com/HanBnrd diff --git a/doc/python_reference.rst b/doc/python_reference.rst index 19775a902a2..d997411f3ab 100644 --- a/doc/python_reference.rst +++ b/doc/python_reference.rst @@ -339,6 +339,7 @@ Projections: rename_channels generate_2d_layout make_1020_channel_selections + combine_channels :py:mod:`mne.preprocessing`: diff --git a/mne/channels/__init__.py b/mne/channels/__init__.py index 771a3628d18..afa26af4818 100644 --- a/mne/channels/__init__.py +++ b/mne/channels/__init__.py @@ -15,8 +15,8 @@ read_custom_montage, read_dig_hpts, compute_native_head_t) from .channels import (equalize_channels, rename_channels, fix_mag_coil_types, - read_ch_adjacency, _get_ch_type, - find_ch_adjacency, make_1020_channel_selections) + read_ch_adjacency, _get_ch_type, find_ch_adjacency, + make_1020_channel_selections, combine_channels) from ..utils import deprecated_alias deprecated_alias('read_ch_connectivity', read_ch_adjacency) deprecated_alias('find_ch_connectivity', find_ch_adjacency) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 24a7049529f..41f13a3bcb2 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -10,6 +10,9 @@ import os import os.path as op import sys +from collections import OrderedDict +from copy import deepcopy +from functools import partial import numpy as np from scipy import sparse @@ -19,7 +22,7 @@ fill_doc, _check_option) from ..io.compensator import get_current_comp from ..io.constants import FIFF -from ..io.meas_info import anonymize_info, Info, MontageMixin +from ..io.meas_info import anonymize_info, Info, MontageMixin, create_info from ..io.pick import (channel_type, pick_info, pick_types, _picks_by_type, _check_excludes_includes, _contains_ch_type, channel_indices_by_type, pick_channels, _picks_to_idx, @@ -1578,3 +1581,141 @@ def make_1020_channel_selections(info, midline="z"): for selection, picks in selections.items()} return selections + + +def combine_channels(inst, groups, method='mean', keep_stim=False, + drop_bad=False): + """Combine channels based on specified channel grouping. + + Parameters + ---------- + inst : instance of Raw, Epochs, or Evoked + An MNE-Python object to combine the channels for. The object can be of + type Raw, Epochs, or Evoked. + groups : dict + Specifies which channels are aggregated into a single channel, with + aggregation method determined by the ``method`` parameter. One new + pseudo-channel is made per dict entry; the dict values must be lists of + picks (integer indices of ``ch_names``). For example:: + + groups=dict(Left=[1, 2, 3, 4], Right=[5, 6, 7, 8]) + + Note that within a dict entry all channels must have the same type. + method : str | callable + Which method to use to combine channels. If a :class:`str`, must be one + of 'mean', 'median', or 'std' (standard deviation). If callable, the + callable must accept one positional input (data of shape ``(n_channels, + n_times)``, or ``(n_epochs, n_channels, n_times)``) and return an + :class:`array ` of shape ``(n_times,)``, or ``(n_epochs, + n_times)``. For example with an instance of Raw or Evoked:: + + method = lambda data: np.mean(data, axis=0) + + Another example with an instance of Epochs:: + + method = lambda data: np.median(data, axis=1) + + Defaults to ``'mean'``. + keep_stim : bool + If ``True``, include stimulus channels in the resulting object. + Defaults to ``False``. + drop_bad : bool + If ``True``, drop channels marked as bad before combining. Defaults to + ``False``. + + Returns + ------- + combined_inst : instance of Raw, Epochs, or Evoked + An MNE-Python object of the same type as the input ``inst``, containing + one virtual channel for each group in ``groups`` (and, if ``keep_stim`` + is ``True``, also containing stimulus channels). + """ + from ..io import BaseRaw, RawArray + from .. import BaseEpochs, EpochsArray, Evoked, EvokedArray + + ch_axis = 1 if isinstance(inst, BaseEpochs) else 0 + ch_idx = list(range(inst.info['nchan'])) + ch_names = inst.info['ch_names'] + ch_types = inst.get_channel_types() + inst_data = inst.data if isinstance(inst, Evoked) else inst.get_data() + groups = OrderedDict(deepcopy(groups)) + + # Convert string values of ``method`` into callables + # XXX Possibly de-duplicate with _make_combine_callable of mne/viz/utils.py + if isinstance(method, str): + method_dict = {key: partial(getattr(np, key), axis=ch_axis) + for key in ('mean', 'median', 'std')} + try: + method = method_dict[method] + except KeyError: + raise ValueError('"method" must be a callable, or one of "mean", ' + f'"median", or "std"; got "{method}".') + + # Instantiate channel info and data + new_ch_names, new_ch_types, new_data = [], [], [] + if not isinstance(keep_stim, bool): + raise TypeError('"keep_stim" must be of type bool, not ' + f'{type(keep_stim)}.') + if keep_stim: + stim_ch_idx = list(pick_types(inst.info, meg=False, stim=True)) + if stim_ch_idx: + new_ch_names = [ch_names[idx] for idx in stim_ch_idx] + new_ch_types = [ch_types[idx] for idx in stim_ch_idx] + new_data = [np.take(inst_data, idx, axis=ch_axis) + for idx in stim_ch_idx] + else: + warn('Could not find stimulus channels.') + + # Get indices of bad channels + ch_idx_bad = [] + if not isinstance(drop_bad, bool): + raise TypeError('"drop_bad" must be of type bool, not ' + f'{type(drop_bad)}.') + if drop_bad and inst.info['bads']: + ch_idx_bad = pick_channels(ch_names, inst.info['bads']) + + # Check correctness of combinations + for this_group, this_picks in groups.items(): + # Check if channel indices are out of bounds + if not all(idx in ch_idx for idx in this_picks): + raise ValueError('Some channel indices are out of bounds.') + # Check if heterogeneous sensor type combinations + this_ch_type = np.array(ch_types)[this_picks] + if len(set(this_ch_type)) > 1: + types = ', '.join(set(this_ch_type)) + raise ValueError('Cannot combine sensors of different types; ' + f'"{this_group}" contains types {types}.') + # Remove bad channels + these_bads = [idx for idx in this_picks if idx in ch_idx_bad] + this_picks = [idx for idx in this_picks if idx not in ch_idx_bad] + if these_bads: + logger.info('Dropped the following channels in group ' + f'{this_group}: {these_bads}') + # Check if combining less than 2 channel + if len(set(this_picks)) < 2: + warn(f'Less than 2 channels in group "{this_group}" when ' + f'combining by method "{method}".') + # If all good create more detailed dict without bad channels + groups[this_group] = dict(picks=this_picks, ch_type=this_ch_type[0]) + + # Combine channels and add them to the new instance + for this_group, this_group_dict in groups.items(): + new_ch_names.append(this_group) + new_ch_types.append(this_group_dict['ch_type']) + this_picks = this_group_dict['picks'] + this_data = np.take(inst_data, this_picks, axis=ch_axis) + new_data.append(method(this_data)) + new_data = np.swapaxes(new_data, 0, ch_axis) + info = create_info(sfreq=inst.info['sfreq'], ch_names=new_ch_names, + ch_types=new_ch_types) + if isinstance(inst, BaseRaw): + combined_inst = RawArray(new_data, info, first_samp=inst.first_samp, + verbose=inst.verbose) + elif isinstance(inst, BaseEpochs): + combined_inst = EpochsArray(new_data, info, tmin=inst.times[0], + verbose=inst.verbose) + elif isinstance(inst, Evoked): + combined_inst = EvokedArray(new_data, info, tmin=inst.times[0], + verbose=inst.verbose) + + return combined_inst diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py index d013d880a45..e4f9b917777 100644 --- a/mne/channels/tests/test_channels.py +++ b/mne/channels/tests/test_channels.py @@ -13,7 +13,7 @@ from scipy.io import savemat from numpy.testing import assert_array_equal, assert_equal -from mne.channels import (rename_channels, read_ch_adjacency, +from mne.channels import (rename_channels, read_ch_adjacency, combine_channels, find_ch_adjacency, make_1020_channel_selections, read_custom_montage, equalize_channels) from mne.channels.channels import (_ch_neighbor_adjacency, @@ -23,12 +23,13 @@ from mne.io.constants import FIFF from mne.utils import _TempDir, run_tests_if_main from mne import (pick_types, pick_channels, EpochsArray, EvokedArray, - make_ad_hoc_cov, create_info) + make_ad_hoc_cov, create_info, read_events, Epochs) from mne.datasets import testing io_dir = op.join(op.dirname(__file__), '..', '..', 'io') base_dir = op.join(io_dir, 'tests', 'data') raw_fname = op.join(base_dir, 'test_raw.fif') +eve_fname = op .join(base_dir, 'test-eve.fif') fname_kit_157 = op.join(io_dir, 'kit', 'tests', 'data', 'test.sqd') @@ -360,4 +361,69 @@ def test_equalize_channels(): assert epochs is epochs2 +def test_combine_channels(): + """Test channel combination on Raw, Epochs, and Evoked.""" + raw = read_raw_fif(raw_fname, preload=True) + raw_ch_bad = read_raw_fif(raw_fname, preload=True) + raw_ch_bad.info['bads'] = ['MEG 0113', 'MEG 0112'] + epochs = Epochs(raw, read_events(eve_fname)) + evoked = epochs.average() + good = dict(foo=[0, 1, 3, 4], bar=[5, 2]) # good grad and mag + + # Test good cases + combine_channels(raw, good) + combine_channels(epochs, good) + combine_channels(evoked, good) + combine_channels(raw, good, drop_bad=True) + combine_channels(raw_ch_bad, good, drop_bad=True) + + # Test with stimulus channels + combine_stim = combine_channels(raw, good, keep_stim=True) + target_nchan = len(good) + len(pick_types(raw.info, meg=False, stim=True)) + assert combine_stim.info['nchan'] == target_nchan + + # Test results with one ROI + good_single = dict(foo=[0, 1, 3, 4]) # good grad + combined_mean = combine_channels(raw, good_single, method='mean') + combined_median = combine_channels(raw, good_single, method='median') + combined_std = combine_channels(raw, good_single, method='std') + foo_mean = np.mean(raw.get_data()[good_single['foo']], axis=0) + foo_median = np.median(raw.get_data()[good_single['foo']], axis=0) + foo_std = np.std(raw.get_data()[good_single['foo']], axis=0) + assert np.array_equal(combined_mean.get_data(), + np.expand_dims(foo_mean, axis=0)) + assert np.array_equal(combined_median.get_data(), + np.expand_dims(foo_median, axis=0)) + assert np.array_equal(combined_std.get_data(), + np.expand_dims(foo_std, axis=0)) + + # Test bad cases + bad1 = dict(foo=[0, 376], bar=[5, 2]) # out of bounds + bad2 = dict(foo=[0, 2], bar=[5, 2]) # type mix in same group + with pytest.raises(ValueError, match='"method" must be a callable, or'): + combine_channels(raw, good, method='bad_method') + with pytest.raises(TypeError, match='"keep_stim" must be of type bool'): + combine_channels(raw, good, keep_stim='bad_type') + with pytest.raises(TypeError, match='"drop_bad" must be of type bool'): + combine_channels(raw, good, drop_bad='bad_type') + with pytest.raises(ValueError, match='Some channel indices are out of'): + combine_channels(raw, bad1) + with pytest.raises(ValueError, match='Cannot combine sensors of diff'): + combine_channels(raw, bad2) + + # Test warnings + raw_no_stim = read_raw_fif(raw_fname, preload=True) + raw_no_stim.pick_types(meg=True, stim=False) + warn1 = dict(foo=[375, 375], bar=[5, 2]) # same channel in same group + warn2 = dict(foo=[375], bar=[5, 2]) # one channel (last channel) + warn3 = dict(foo=[0, 4], bar=[5, 2]) # one good channel left + with pytest.warns(RuntimeWarning, match='Could not find stimulus'): + combine_channels(raw_no_stim, good, keep_stim=True) + with pytest.warns(RuntimeWarning, match='Less than 2 channels') as record: + combine_channels(raw, warn1) + combine_channels(raw, warn2) + combine_channels(raw_ch_bad, warn3, drop_bad=True) + assert len(record) == 3 + + run_tests_if_main() diff --git a/tutorials/evoked/plot_eeg_erp.py b/tutorials/evoked/plot_eeg_erp.py index 8ffbdbac88a..d18f4d31171 100644 --- a/tutorials/evoked/plot_eeg_erp.py +++ b/tutorials/evoked/plot_eeg_erp.py @@ -12,6 +12,7 @@ import mne from mne.datasets import sample +from mne.channels import combine_channels ############################################################################### # Setup for reading the raw data @@ -107,6 +108,26 @@ evoked_custom.plot(titles=dict(eeg=title), time_unit='s') evoked_custom.plot_topomap(times=[0.1], size=3., title=title, time_unit='s') +############################################################################### +# Evoked response averaged across channels by ROI +# ----------------------------------------------- +# +# It is possible to average channels by region of interest (for example left +# and right) when studying the response to this left auditory stimulus. Here we +# use our Raw object on which the average reference projection has been added +# back. +evoked = mne.Epochs(raw, **epochs_params).average() + +left_idx = mne.pick_channels(evoked.info['ch_names'], + ['EEG 017', 'EEG 018', 'EEG 025', 'EEG 026']) +right_idx = mne.pick_channels(evoked.info['ch_names'], + ['EEG 023', 'EEG 024', 'EEG 034', 'EEG 035']) +roi_dict = dict(Left=left_idx, Right=right_idx) +evoked_combined = combine_channels(evoked, roi_dict, method='mean') + +title = 'Evoked response averaged by side' +evoked_combined.plot(titles=dict(eeg=title), time_unit='s') + ############################################################################### # Evoked arithmetic (e.g. differences) # ------------------------------------