From a85093cdd523726d62629724344fb941c3e03908 Mon Sep 17 00:00:00 2001 From: Ana Radanovic Date: Mon, 25 Sep 2023 14:44:04 -0400 Subject: [PATCH 01/38] adding unify channels to preproc --- mne/preprocessing/unify_bads.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 mne/preprocessing/unify_bads.py diff --git a/mne/preprocessing/unify_bads.py b/mne/preprocessing/unify_bads.py new file mode 100644 index 00000000000..6b24d2d078a --- /dev/null +++ b/mne/preprocessing/unify_bads.py @@ -0,0 +1,32 @@ + +#%% +def unifying_bads( + list_instances, + +): + + common_bad_channels = [] + #first check that each object is mne object + + #then interate through the objects to get ch names as set + ch_set_1 = list_instances[0].info['bads'] + common_bad_channels.extend(ch_set_1) + + for inst in list_instances[1:]: + ch_set_2 = set(inst.info['bads']) + set_of_bads = set(common_bad_channels) + new_bads = ch_set_2.difference(set_of_bads) + + if len(new_bads) >1 : + common_bad_channels.extend(list(new_bads)) + + new_instances = [] + + for inst in list_instances: + inst.info["bads"] = common_bad_channels + new_instances.append(inst) + + + return new_instances + + From 02ecd49918c9fb320e18734881f7e679ac84325e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Sep 2023 18:46:50 +0000 Subject: [PATCH 02/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/preprocessing/unify_bads.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/mne/preprocessing/unify_bads.py b/mne/preprocessing/unify_bads.py index 6b24d2d078a..1145192aa88 100644 --- a/mne/preprocessing/unify_bads.py +++ b/mne/preprocessing/unify_bads.py @@ -1,32 +1,26 @@ - -#%% +# %% def unifying_bads( list_instances, - ): - common_bad_channels = [] - #first check that each object is mne object - - #then interate through the objects to get ch names as set - ch_set_1 = list_instances[0].info['bads'] + # first check that each object is mne object + + # then interate through the objects to get ch names as set + ch_set_1 = list_instances[0].info["bads"] common_bad_channels.extend(ch_set_1) - + for inst in list_instances[1:]: - ch_set_2 = set(inst.info['bads']) + ch_set_2 = set(inst.info["bads"]) set_of_bads = set(common_bad_channels) new_bads = ch_set_2.difference(set_of_bads) - - if len(new_bads) >1 : + + if len(new_bads) > 1: common_bad_channels.extend(list(new_bads)) - + new_instances = [] - + for inst in list_instances: inst.info["bads"] = common_bad_channels new_instances.append(inst) - return new_instances - - From 0284caaab74d51026a49765f96f51df1c4f72d66 Mon Sep 17 00:00:00 2001 From: Ana Radanovic <79697247+anaradanovic@users.noreply.github.com> Date: Mon, 25 Sep 2023 17:13:28 -0400 Subject: [PATCH 03/38] Update mne/preprocessing/unify_bads.py Co-authored-by: Daniel McCloy --- mne/preprocessing/unify_bads.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mne/preprocessing/unify_bads.py b/mne/preprocessing/unify_bads.py index 1145192aa88..1c0cf1b6793 100644 --- a/mne/preprocessing/unify_bads.py +++ b/mne/preprocessing/unify_bads.py @@ -4,7 +4,14 @@ def unifying_bads( ): common_bad_channels = [] # first check that each object is mne object - + inst_types = set(type(insts[0])) + valid_types = (Raw, Epochs, Evoked, Spectrum, EpochsSpectrum) + for inst in insts: + _validate_type(inst, valid_types, "instance type") + if type(inst) not in inst_types: + raise ValueError( + "all insts must be the same type" + ) # then interate through the objects to get ch names as set ch_set_1 = list_instances[0].info["bads"] common_bad_channels.extend(ch_set_1) From c3e979481ef2bd789969c576452ef5aa27524eae Mon Sep 17 00:00:00 2001 From: Ana Radanovic <79697247+anaradanovic@users.noreply.github.com> Date: Mon, 25 Sep 2023 17:13:39 -0400 Subject: [PATCH 04/38] Update mne/preprocessing/unify_bads.py Co-authored-by: Daniel McCloy --- mne/preprocessing/unify_bads.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mne/preprocessing/unify_bads.py b/mne/preprocessing/unify_bads.py index 1c0cf1b6793..b5ea971f9b6 100644 --- a/mne/preprocessing/unify_bads.py +++ b/mne/preprocessing/unify_bads.py @@ -15,7 +15,10 @@ def unifying_bads( # then interate through the objects to get ch names as set ch_set_1 = list_instances[0].info["bads"] common_bad_channels.extend(ch_set_1) - + all_bads = dict() + for inst in insts: + all_bads.update(dict.fromkeys(inst.info['bads'])) + all_bads = list(all_bads) for inst in list_instances[1:]: ch_set_2 = set(inst.info["bads"]) set_of_bads = set(common_bad_channels) From aeb7a42495c8422d40d5d6810a15d00535ba754d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Sep 2023 21:14:20 +0000 Subject: [PATCH 05/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/preprocessing/unify_bads.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mne/preprocessing/unify_bads.py b/mne/preprocessing/unify_bads.py index b5ea971f9b6..a75034f390e 100644 --- a/mne/preprocessing/unify_bads.py +++ b/mne/preprocessing/unify_bads.py @@ -9,9 +9,7 @@ def unifying_bads( for inst in insts: _validate_type(inst, valid_types, "instance type") if type(inst) not in inst_types: - raise ValueError( - "all insts must be the same type" - ) + raise ValueError("all insts must be the same type") # then interate through the objects to get ch names as set ch_set_1 = list_instances[0].info["bads"] common_bad_channels.extend(ch_set_1) From 32f8b35e300f9a4c9031c0e952b8d8c76c310892 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Sep 2023 21:15:16 +0000 Subject: [PATCH 06/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/preprocessing/unify_bads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/unify_bads.py b/mne/preprocessing/unify_bads.py index a75034f390e..be2ca2ac198 100644 --- a/mne/preprocessing/unify_bads.py +++ b/mne/preprocessing/unify_bads.py @@ -15,7 +15,7 @@ def unifying_bads( common_bad_channels.extend(ch_set_1) all_bads = dict() for inst in insts: - all_bads.update(dict.fromkeys(inst.info['bads'])) + all_bads.update(dict.fromkeys(inst.info["bads"])) all_bads = list(all_bads) for inst in list_instances[1:]: ch_set_2 = set(inst.info["bads"]) From c22295658ae7ab474bc3ba4b826c23854b3df34a Mon Sep 17 00:00:00 2001 From: Ana Radanovic Date: Mon, 25 Sep 2023 18:46:06 -0400 Subject: [PATCH 07/38] updates to inprogress work unify bads --- mne/preprocessing/unify_bads.py | 38 +++++++++++++++++---------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/mne/preprocessing/unify_bads.py b/mne/preprocessing/unify_bads.py index be2ca2ac198..aa32e403fa4 100644 --- a/mne/preprocessing/unify_bads.py +++ b/mne/preprocessing/unify_bads.py @@ -1,34 +1,36 @@ -# %% -def unifying_bads( - list_instances, +#%% + +from ..utils import _validate_type + +from ..io import BaseRaw +from ..epochs import Epochs +from ..evoked import Evoked +from ..time_frequency.spectrum import BaseSpectrum + +#%% +def unify_bad_channels( + insts, ): - common_bad_channels = [] # first check that each object is mne object - inst_types = set(type(insts[0])) - valid_types = (Raw, Epochs, Evoked, Spectrum, EpochsSpectrum) + inst_type = type(insts[0]) + valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) for inst in insts: - _validate_type(inst, valid_types, "instance type") - if type(inst) not in inst_types: + _validate_type(inst, valid_types , "instance type") + if type(inst) != inst_type: raise ValueError("all insts must be the same type") # then interate through the objects to get ch names as set - ch_set_1 = list_instances[0].info["bads"] - common_bad_channels.extend(ch_set_1) + all_bads = dict() for inst in insts: all_bads.update(dict.fromkeys(inst.info["bads"])) all_bads = list(all_bads) - for inst in list_instances[1:]: - ch_set_2 = set(inst.info["bads"]) - set_of_bads = set(common_bad_channels) - new_bads = ch_set_2.difference(set_of_bads) - if len(new_bads) > 1: - common_bad_channels.extend(list(new_bads)) new_instances = [] - for inst in list_instances: - inst.info["bads"] = common_bad_channels + for inst in insts: + inst.info["bads"] = all_bads new_instances.append(inst) return new_instances + From 0fa5b18f0c1d327a9d11dca5866dfe81d6fd0bd1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Sep 2023 22:49:36 +0000 Subject: [PATCH 08/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/preprocessing/unify_bads.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mne/preprocessing/unify_bads.py b/mne/preprocessing/unify_bads.py index aa32e403fa4..c8b7d08e8e0 100644 --- a/mne/preprocessing/unify_bads.py +++ b/mne/preprocessing/unify_bads.py @@ -1,4 +1,4 @@ -#%% +# %% from ..utils import _validate_type @@ -7,7 +7,8 @@ from ..evoked import Evoked from ..time_frequency.spectrum import BaseSpectrum -#%% + +# %% def unify_bad_channels( insts, ): @@ -15,7 +16,7 @@ def unify_bad_channels( inst_type = type(insts[0]) valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) for inst in insts: - _validate_type(inst, valid_types , "instance type") + _validate_type(inst, valid_types, "instance type") if type(inst) != inst_type: raise ValueError("all insts must be the same type") # then interate through the objects to get ch names as set @@ -25,7 +26,6 @@ def unify_bad_channels( all_bads.update(dict.fromkeys(inst.info["bads"])) all_bads = list(all_bads) - new_instances = [] for inst in insts: @@ -33,4 +33,3 @@ def unify_bad_channels( new_instances.append(inst) return new_instances - From 2c200a1dfc94a7667cc6e3c65425853e77b5e409 Mon Sep 17 00:00:00 2001 From: Ana Radanovic Date: Tue, 26 Sep 2023 15:53:01 -0400 Subject: [PATCH 09/38] Move unify_bad_channels function to preprocessing/bads.py --- mne/preprocessing/bads.py | 66 +++++++++++++++++++++++++++++++++ mne/preprocessing/unify_bads.py | 36 ------------------ 2 files changed, 66 insertions(+), 36 deletions(-) delete mode 100644 mne/preprocessing/unify_bads.py diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 077145ff7c8..72a0a72c250 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -1,10 +1,19 @@ # Authors: Denis Engemann +# Ana Radanovic +# Erica Peterson # License: BSD-3-Clause import numpy as np from scipy.stats import zscore +from ..utils import _validate_type + +from ..io import BaseRaw +from ..epochs import Epochs +from ..evoked import Evoked +from ..time_frequency.spectrum import BaseSpectrum + def _find_outliers(X, threshold=3.0, max_iter=2, tail=0): """Find outliers based on iterated Z-scoring. @@ -47,3 +56,60 @@ def _find_outliers(X, threshold=3.0, max_iter=2, tail=0): bad_idx = np.where(my_mask)[0] return bad_idx + + +def unify_bad_channels(insts): + """Unify bad channels across list of instances. + + This function looks across the list of instances to gather a list of + "bad" channels. Every instance's info["bads"] is set with the same "bad" + channels. + + Parameters + ---------- + insts : list + List of instances (:class:`~mne.io.Raw`, :class:`~mne.Epochs`, + :class:`~mne.Evoked`, :class:`~mne.time_frequency.Spectrum`, + :class:`~mne.time_frequency.EpochSpectrum`) to unify bad channels. + + Returns + ------- + new_insts : list + List of instances with bad channels unified across instances. + + Notes + ----- + This function operates in-place. + + .. versionadded:: 1.6 + """ + # first check that each object is mne object + inst_type = type(insts[0]) + valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) + for inst in insts: + _validate_type(inst, valid_types , "instance type") + if type(inst) != inst_type: + raise ValueError("All insts must be the same type") + + #check that input is a list + if type(insts) != list : + raise ValueError(f"insts must be a *list* of mne objects, got {type(insts)}") + + if len(insts) == 0: + raise ValueError("Be sure insts is not empty list") + + # then interate through the insts to gather bads + all_bads = dict() + for inst in insts: + #using dictionary method to remove duplicates & preserve order + all_bads.update(dict.fromkeys(inst.info["bads"])) + all_bads = list(all_bads) + + #apply bads set to all instances + new_insts = [] + + for inst in insts: + inst.info["bads"] = all_bads + new_insts.append(inst) + + return new_insts diff --git a/mne/preprocessing/unify_bads.py b/mne/preprocessing/unify_bads.py deleted file mode 100644 index aa32e403fa4..00000000000 --- a/mne/preprocessing/unify_bads.py +++ /dev/null @@ -1,36 +0,0 @@ -#%% - -from ..utils import _validate_type - -from ..io import BaseRaw -from ..epochs import Epochs -from ..evoked import Evoked -from ..time_frequency.spectrum import BaseSpectrum - -#%% -def unify_bad_channels( - insts, -): - # first check that each object is mne object - inst_type = type(insts[0]) - valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) - for inst in insts: - _validate_type(inst, valid_types , "instance type") - if type(inst) != inst_type: - raise ValueError("all insts must be the same type") - # then interate through the objects to get ch names as set - - all_bads = dict() - for inst in insts: - all_bads.update(dict.fromkeys(inst.info["bads"])) - all_bads = list(all_bads) - - - new_instances = [] - - for inst in insts: - inst.info["bads"] = all_bads - new_instances.append(inst) - - return new_instances - From 80c329d1dc5eede1c1345f18404130953e3eccd3 Mon Sep 17 00:00:00 2001 From: Ana Radanovic Date: Tue, 26 Sep 2023 16:05:26 -0400 Subject: [PATCH 10/38] add function to namespace --- mne/preprocessing/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne/preprocessing/__init__.py b/mne/preprocessing/__init__.py index 8358f006e09..8aea05c369f 100644 --- a/mne/preprocessing/__init__.py +++ b/mne/preprocessing/__init__.py @@ -56,5 +56,6 @@ "interpolate": ["equalize_bads", "interpolate_bridged_electrodes"], "_css": ["cortical_signal_suppression"], "hfc": ["compute_proj_hfc"], + "bads": ["unify_bad_channels"] }, ) From fc917873cec6b10b99e33929dbe1103cdca1e796 Mon Sep 17 00:00:00 2001 From: Ana Radanovic Date: Tue, 26 Sep 2023 16:20:46 -0400 Subject: [PATCH 11/38] style fixes --- mne/preprocessing/__init__.py | 2 +- mne/preprocessing/bads.py | 31 ++++++++++++++++--------------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/mne/preprocessing/__init__.py b/mne/preprocessing/__init__.py index 8aea05c369f..6d818ac60b0 100644 --- a/mne/preprocessing/__init__.py +++ b/mne/preprocessing/__init__.py @@ -56,6 +56,6 @@ "interpolate": ["equalize_bads", "interpolate_bridged_electrodes"], "_css": ["cortical_signal_suppression"], "hfc": ["compute_proj_hfc"], - "bads": ["unify_bad_channels"] + "bads": ["unify_bad_channels"], }, ) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 72a0a72c250..453287dd311 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -14,6 +14,7 @@ from ..evoked import Evoked from ..time_frequency.spectrum import BaseSpectrum + def _find_outliers(X, threshold=3.0, max_iter=2, tail=0): """Find outliers based on iterated Z-scoring. @@ -61,51 +62,51 @@ def _find_outliers(X, threshold=3.0, max_iter=2, tail=0): def unify_bad_channels(insts): """Unify bad channels across list of instances. - This function looks across the list of instances to gather a list of - "bad" channels. Every instance's info["bads"] is set with the same "bad" + This function looks across the list of instances to gather a list of + "bad" channels. Every instance's info["bads"] is set with the same "bad" channels. Parameters ---------- insts : list - List of instances (:class:`~mne.io.Raw`, :class:`~mne.Epochs`, - :class:`~mne.Evoked`, :class:`~mne.time_frequency.Spectrum`, + List of instances (:class:`~mne.io.Raw`, :class:`~mne.Epochs`, + :class:`~mne.Evoked`, :class:`~mne.time_frequency.Spectrum`, :class:`~mne.time_frequency.EpochSpectrum`) to unify bad channels. Returns ------- new_insts : list List of instances with bad channels unified across instances. - + Notes ----- This function operates in-place. - + .. versionadded:: 1.6 """ # first check that each object is mne object inst_type = type(insts[0]) valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) for inst in insts: - _validate_type(inst, valid_types , "instance type") + _validate_type(inst, valid_types, "instance type") if type(inst) != inst_type: raise ValueError("All insts must be the same type") - - #check that input is a list - if type(insts) != list : + + # check that input is a list + if not isinstance(insts, list): raise ValueError(f"insts must be a *list* of mne objects, got {type(insts)}") - + if len(insts) == 0: raise ValueError("Be sure insts is not empty list") - - # then interate through the insts to gather bads + + # then iterate through the insts to gather bads all_bads = dict() for inst in insts: - #using dictionary method to remove duplicates & preserve order + # using dictionary method to remove duplicates & preserve order all_bads.update(dict.fromkeys(inst.info["bads"])) all_bads = list(all_bads) - #apply bads set to all instances + # apply bads set to all instances new_insts = [] for inst in insts: From 58b96c3ce8106f1ff5b88d973dbfc32c59a090a6 Mon Sep 17 00:00:00 2001 From: nordme Date: Tue, 26 Sep 2023 08:45:40 -0700 Subject: [PATCH 12/38] test draft --- mne/preprocessing/tests/test_unify_bads.py | 58 ++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 mne/preprocessing/tests/test_unify_bads.py diff --git a/mne/preprocessing/tests/test_unify_bads.py b/mne/preprocessing/tests/test_unify_bads.py new file mode 100644 index 00000000000..8c696ba6e36 --- /dev/null +++ b/mne/preprocessing/tests/test_unify_bads.py @@ -0,0 +1,58 @@ +from pathlib import Path + +import numpy as np +import pytest + +import mne +from mne.preprocessing import unify_bad_channels +from ..io import BaseRaw +from ..epochs import Epochs +from ..evoked import Evoked +from ..time_frequency.spectrum import BaseSpectrum + + +raw_fname = ( + Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test_raw.fif" +) +@pytest.mark.parametrize('instance', ('raw', 'epochs', 'evoked', 'spectrum')) +def test_instance_support(instance): + # test unify_bads function on instance (single input, no bads scenario) + unify_bad_channels(instance) + + +def test_bads_order(raw): + +def test_unify_bads(raw, epochs): + ## test error raising + # err 1: no instance passed to function + with pytest.raises(UserWarning): # FIX RAISE TYPE + unify_bad_channels([]) + # err 2: bad instance passed to function + bad_inst = 'bad_instance' + with pytest.raises(UserWarning): # FIX RAISE TYPE + unify_bad_channels(bad_inst) + # err 3: mixed instance types passed to function + with pytest.raises(ValueError): + unify_bad_channels([raw, epochs]) + ## check unification scenarios + # scnario 1: single instance with actual bads (already tested no bads) + raw.info['bads'] += raw.info['ch_names'][0] + s_unified = unify_bad_channels(raw) + assert len(s_unified) == 1 + assert s_unified[0].info['bads'] == raw.info['ch_names'][0] + # scenario 2: multiple instances + # a) empty bads list, b) unique bads, c) overlapping out-of-order bads, + # d) channel name not included in raw channels + chns = raw.info['ch_names'][0:3] + raws = [raw, raw.copy(), raw.copy(), raw.copy(), raw.copy()] + assert raws[0].info['bads'] == [chns[0]] + raws[1].info['bads'] = [] + raws[2].info['bads'] = [chns[2]] + raws[3].info['bads'] = [chns[2], chns[1]] + raws[4].info['bads'] = ['nonsense_ch_name'] + # use unify_bads function + m_unified = unify_bad_channels(raws) + assert len(m_unified) == len(raws) + # check results + for i in np.arange(len(m_unified)): + assert m_unified[i].info['bads'] == From 7c28118dbf77ccc1f13e4443ad13aa35a40cb70b Mon Sep 17 00:00:00 2001 From: nordme Date: Tue, 26 Sep 2023 12:57:20 -0700 Subject: [PATCH 13/38] make functions work --- mne/preprocessing/tests/test_unify_bads.py | 67 +++++++++++----------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/mne/preprocessing/tests/test_unify_bads.py b/mne/preprocessing/tests/test_unify_bads.py index 8c696ba6e36..f01a9d58809 100644 --- a/mne/preprocessing/tests/test_unify_bads.py +++ b/mne/preprocessing/tests/test_unify_bads.py @@ -1,49 +1,45 @@ -from pathlib import Path - import numpy as np import pytest -import mne -from mne.preprocessing import unify_bad_channels -from ..io import BaseRaw -from ..epochs import Epochs -from ..evoked import Evoked -from ..time_frequency.spectrum import BaseSpectrum - +from mne.preprocessing.unify_bads import unify_bad_channels +from mne.time_frequency.tests.test_spectrum import _get_inst() -raw_fname = ( - Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test_raw.fif" -) -@pytest.mark.parametrize('instance', ('raw', 'epochs', 'evoked', 'spectrum')) -def test_instance_support(instance): +@pytest.mark.parametrize('instance', ('raw', 'epochs', 'evoked', 'raw_spectrum', 'epochs_spectrum')) +def test_instance_support(instance, request, evoked): + '''Tests support of different classes.''' # test unify_bads function on instance (single input, no bads scenario) - unify_bad_channels(instance) + inst = _get_inst(instance, request, evoked) + inst_out = unify_bad_channels([inst]) + assert inst_out == [inst] + +@pytest.mark.parametrize('instance', ([], ['bad_instance'], 'mixed')) +def test_error_raising(instance, request, evoked): + '''Tests input checking.''' + if instance == 'mixed': + instance = [_get_inst('raw', request, evoked), + _get_inst('epochs', request, evoked)] + with pytest.raises(TypeError): + unify_bad_channels(instance) -def test_bads_order(raw): -def test_unify_bads(raw, epochs): - ## test error raising - # err 1: no instance passed to function - with pytest.raises(UserWarning): # FIX RAISE TYPE - unify_bad_channels([]) - # err 2: bad instance passed to function - bad_inst = 'bad_instance' - with pytest.raises(UserWarning): # FIX RAISE TYPE - unify_bad_channels(bad_inst) - # err 3: mixed instance types passed to function - with pytest.raises(ValueError): - unify_bad_channels([raw, epochs]) +def test_bads_compilation(raw): + '''Tests that bads are compiled properly in two cases: a) single instance + passed to function with an existing bad, and b) multiple instances passed + to function with varying compilation scenarios including empty bads, + unique bads, partially duplicated bads listed out-of-order, and + nonsense channel names.''' ## check unification scenarios - # scnario 1: single instance with actual bads (already tested no bads) - raw.info['bads'] += raw.info['ch_names'][0] - s_unified = unify_bad_channels(raw) + chns = raw.info['ch_names'][0:3] + # scnario 1: single instance passed with actual bads (already tested no bads) + assert raw.info['bads'] == [] + raw.info['bads'] += [chns[0]] + s_unified = unify_bad_channels([raw]) assert len(s_unified) == 1 - assert s_unified[0].info['bads'] == raw.info['ch_names'][0] - # scenario 2: multiple instances + assert s_unified[0].info['bads'] == [chns[0]], (s_unified[0].info['bads'], chns[0]) + # scenario 2: multiple instances passed, bads types as follows: # a) empty bads list, b) unique bads, c) overlapping out-of-order bads, # d) channel name not included in raw channels - chns = raw.info['ch_names'][0:3] raws = [raw, raw.copy(), raw.copy(), raw.copy(), raw.copy()] assert raws[0].info['bads'] == [chns[0]] raws[1].info['bads'] = [] @@ -54,5 +50,6 @@ def test_unify_bads(raw, epochs): m_unified = unify_bad_channels(raws) assert len(m_unified) == len(raws) # check results + correct_bads = [chns[0], chns[2], chns[1], 'nonsense_ch_name'] for i in np.arange(len(m_unified)): - assert m_unified[i].info['bads'] == + assert m_unified[i].info['bads'] == correct_bads From d89431e313e0595f16184be9f3cb7a94c93c8397 Mon Sep 17 00:00:00 2001 From: nordme Date: Tue, 26 Sep 2023 13:55:16 -0700 Subject: [PATCH 14/38] style fixes --- mne/preprocessing/tests/test_unify_bads.py | 59 ++++++++++++---------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/mne/preprocessing/tests/test_unify_bads.py b/mne/preprocessing/tests/test_unify_bads.py index f01a9d58809..96edf86c1ae 100644 --- a/mne/preprocessing/tests/test_unify_bads.py +++ b/mne/preprocessing/tests/test_unify_bads.py @@ -1,55 +1,62 @@ import numpy as np import pytest -from mne.preprocessing.unify_bads import unify_bad_channels -from mne.time_frequency.tests.test_spectrum import _get_inst() +from mne.preprocessing import unify_bad_channels +from mne.time_frequency.tests.test_spectrum import _get_inst -@pytest.mark.parametrize('instance', ('raw', 'epochs', 'evoked', 'raw_spectrum', 'epochs_spectrum')) + +@pytest.mark.parametrize( + "instance", ("raw", "epochs", "evoked", "raw_spectrum", "epochs_spectrum") +) def test_instance_support(instance, request, evoked): - '''Tests support of different classes.''' + """Tests support of different classes.""" # test unify_bads function on instance (single input, no bads scenario) inst = _get_inst(instance, request, evoked) inst_out = unify_bad_channels([inst]) assert inst_out == [inst] -@pytest.mark.parametrize('instance', ([], ['bad_instance'], 'mixed')) +@pytest.mark.parametrize("instance", ([], ["bad_instance"], "mixed")) def test_error_raising(instance, request, evoked): - '''Tests input checking.''' - if instance == 'mixed': - instance = [_get_inst('raw', request, evoked), - _get_inst('epochs', request, evoked)] + """Tests input checking.""" + if instance == "mixed": + instance = [ + _get_inst("raw", request, evoked), + _get_inst("epochs", request, evoked), + ] with pytest.raises(TypeError): unify_bad_channels(instance) def test_bads_compilation(raw): - '''Tests that bads are compiled properly in two cases: a) single instance - passed to function with an existing bad, and b) multiple instances passed - to function with varying compilation scenarios including empty bads, - unique bads, partially duplicated bads listed out-of-order, and - nonsense channel names.''' + """Tests that bads are compiled properly. + + Tests two cases: a) single instance passed to function with an existing + bad, and b) multiple instances passed to function with varying compilation + scenarios including empty bads, unique bads, partially duplicated bads + listed out-of-order, and nonsense channel names. + """ ## check unification scenarios - chns = raw.info['ch_names'][0:3] - # scnario 1: single instance passed with actual bads (already tested no bads) - assert raw.info['bads'] == [] - raw.info['bads'] += [chns[0]] + chns = raw.info["ch_names"][0:3] + # scenario 1: single instance passed with actual bads (already tested no bads) + assert raw.info["bads"] == [] + raw.info["bads"] += [chns[0]] s_unified = unify_bad_channels([raw]) assert len(s_unified) == 1 - assert s_unified[0].info['bads'] == [chns[0]], (s_unified[0].info['bads'], chns[0]) + assert s_unified[0].info["bads"] == [chns[0]], (s_unified[0].info["bads"], chns[0]) # scenario 2: multiple instances passed, bads types as follows: # a) empty bads list, b) unique bads, c) overlapping out-of-order bads, # d) channel name not included in raw channels raws = [raw, raw.copy(), raw.copy(), raw.copy(), raw.copy()] - assert raws[0].info['bads'] == [chns[0]] - raws[1].info['bads'] = [] - raws[2].info['bads'] = [chns[2]] - raws[3].info['bads'] = [chns[2], chns[1]] - raws[4].info['bads'] = ['nonsense_ch_name'] + assert raws[0].info["bads"] == [chns[0]] + raws[1].info["bads"] = [] + raws[2].info["bads"] = [chns[2]] + raws[3].info["bads"] = [chns[2], chns[1]] + raws[4].info["bads"] = ["nonsense_ch_name"] # use unify_bads function m_unified = unify_bad_channels(raws) assert len(m_unified) == len(raws) # check results - correct_bads = [chns[0], chns[2], chns[1], 'nonsense_ch_name'] + correct_bads = [chns[0], chns[2], chns[1], "nonsense_ch_name"] for i in np.arange(len(m_unified)): - assert m_unified[i].info['bads'] == correct_bads + assert m_unified[i].info["bads"] == correct_bads From 9b745cb6ece743d857fc5692f6223a645b23dd39 Mon Sep 17 00:00:00 2001 From: nordme Date: Wed, 27 Sep 2023 09:54:21 -0700 Subject: [PATCH 15/38] style fix --- mne/preprocessing/tests/test_unify_bads.py | 63 ++++++++++------------ 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/mne/preprocessing/tests/test_unify_bads.py b/mne/preprocessing/tests/test_unify_bads.py index 96edf86c1ae..da4d01716de 100644 --- a/mne/preprocessing/tests/test_unify_bads.py +++ b/mne/preprocessing/tests/test_unify_bads.py @@ -16,16 +16,18 @@ def test_instance_support(instance, request, evoked): assert inst_out == [inst] -@pytest.mark.parametrize("instance", ([], ["bad_instance"], "mixed")) -def test_error_raising(instance, request, evoked): +def test_error_raising(raw, epochs): """Tests input checking.""" - if instance == "mixed": - instance = [ - _get_inst("raw", request, evoked), - _get_inst("epochs", request, evoked), - ] - with pytest.raises(TypeError): - unify_bad_channels(instance) + with pytest.raises(IndexError, match=r"empty list"): + unify_bad_channels([]) + with pytest.raises(TypeError, match=r"must be an instance of"): + unify_bad_channels(["bad_instance"]) + with pytest.raises(ValueError, match=r"same type"): + unify_bad_channels([raw, epochs]) + with pytest.raises(AssertionError): + raw_alt = raw.copy() + raw_alt.info["ch_names"] = raw.info["ch_names"][0] + unify_bad_channels([raw, raw_alt]) def test_bads_compilation(raw): @@ -33,30 +35,23 @@ def test_bads_compilation(raw): Tests two cases: a) single instance passed to function with an existing bad, and b) multiple instances passed to function with varying compilation - scenarios including empty bads, unique bads, partially duplicated bads - listed out-of-order, and nonsense channel names. + scenarios including empty bads, unique bads, and partially duplicated bads + listed out-of-order. """ - ## check unification scenarios - chns = raw.info["ch_names"][0:3] - # scenario 1: single instance passed with actual bads (already tested no bads) assert raw.info["bads"] == [] - raw.info["bads"] += [chns[0]] - s_unified = unify_bad_channels([raw]) - assert len(s_unified) == 1 - assert s_unified[0].info["bads"] == [chns[0]], (s_unified[0].info["bads"], chns[0]) - # scenario 2: multiple instances passed, bads types as follows: - # a) empty bads list, b) unique bads, c) overlapping out-of-order bads, - # d) channel name not included in raw channels - raws = [raw, raw.copy(), raw.copy(), raw.copy(), raw.copy()] - assert raws[0].info["bads"] == [chns[0]] - raws[1].info["bads"] = [] - raws[2].info["bads"] = [chns[2]] - raws[3].info["bads"] = [chns[2], chns[1]] - raws[4].info["bads"] = ["nonsense_ch_name"] - # use unify_bads function - m_unified = unify_bad_channels(raws) - assert len(m_unified) == len(raws) - # check results - correct_bads = [chns[0], chns[2], chns[1], "nonsense_ch_name"] - for i in np.arange(len(m_unified)): - assert m_unified[i].info["bads"] == correct_bads + chns = raw.info["ch_names"][0:3] + no_bad = raw.copy() + one_bad = raw.copy() + one_bad.info["bads"] = chns[1] + three_bad = raw.copy() + three_bad.info["bads"] = chns + # scenario 1: single instance passed with actual bads + s_out = unify_bad_channels([one_bad]) + assert len(s_out) == 1 + assert s_out[0].info["bads"] == [chns[1]], (s_out[0].info["bads"], chns[1]) + # scenario 2: multiple instances passed + m_out = unify_bad_channels([one_bad, no_bad, three_bad]) + assert len(m_out) == 3 + correct_bads = [chns[1], chns[0], chns[2]] + for i in np.arange(len(m_out)): + assert m_out[i].info["bads"] == correct_bads From b32f318d83548e47c18bb9c6be0fd629a3eac5f8 Mon Sep 17 00:00:00 2001 From: nordme Date: Wed, 27 Sep 2023 10:12:16 -0700 Subject: [PATCH 16/38] fixes --- mne/preprocessing/tests/test_unify_bads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/preprocessing/tests/test_unify_bads.py b/mne/preprocessing/tests/test_unify_bads.py index da4d01716de..6ca9d60fa72 100644 --- a/mne/preprocessing/tests/test_unify_bads.py +++ b/mne/preprocessing/tests/test_unify_bads.py @@ -18,7 +18,7 @@ def test_instance_support(instance, request, evoked): def test_error_raising(raw, epochs): """Tests input checking.""" - with pytest.raises(IndexError, match=r"empty list"): + with pytest.raises(IndexError, match=r"list index"): unify_bad_channels([]) with pytest.raises(TypeError, match=r"must be an instance of"): unify_bad_channels(["bad_instance"]) @@ -42,7 +42,7 @@ def test_bads_compilation(raw): chns = raw.info["ch_names"][0:3] no_bad = raw.copy() one_bad = raw.copy() - one_bad.info["bads"] = chns[1] + one_bad.info["bads"] = [chns[1]] three_bad = raw.copy() three_bad.info["bads"] = chns # scenario 1: single instance passed with actual bads From 023db18096be813d7470041667847261a171809d Mon Sep 17 00:00:00 2001 From: Ana Radanovic Date: Wed, 27 Sep 2023 13:17:06 -0400 Subject: [PATCH 17/38] adding ch_name check to function --- mne/preprocessing/bads.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 453287dd311..b566495a79d 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -95,9 +95,22 @@ def unify_bad_channels(insts): # check that input is a list if not isinstance(insts, list): raise ValueError(f"insts must be a *list* of mne objects, got {type(insts)}") - + # check input is not an empty list if len(insts) == 0: raise ValueError("Be sure insts is not empty list") + # check that all channels have the same name and same number + ch_names = insts[0].info.ch_names + diff_chns = [] + for inst in insts[1:]: + if inst.info.ch_names != ch_names: + dif = set(inst.info.ch_names).difference(set(ch_names)) + diff_chns.extend(list(dif)) + + if len(diff_chns) > 0: + raise ValueError( + "Channel names are not consistent across instances." + f" Mismatch channels are {diff_chns}" + ) # then iterate through the insts to gather bads all_bads = dict() From ce03592f0085a2e1d28047992f709696d5d1687c Mon Sep 17 00:00:00 2001 From: Ana Radanovic Date: Wed, 27 Sep 2023 13:22:06 -0400 Subject: [PATCH 18/38] moving len check to first check --- mne/preprocessing/bads.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index b566495a79d..0e3bc46fd55 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -84,6 +84,10 @@ def unify_bad_channels(insts): .. versionadded:: 1.6 """ + # check input is not an empty list + if len(insts) == 0: + raise ValueError("Be sure insts is not empty list") + # first check that each object is mne object inst_type = type(insts[0]) valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) @@ -95,9 +99,7 @@ def unify_bad_channels(insts): # check that input is a list if not isinstance(insts, list): raise ValueError(f"insts must be a *list* of mne objects, got {type(insts)}") - # check input is not an empty list - if len(insts) == 0: - raise ValueError("Be sure insts is not empty list") + # check that all channels have the same name and same number ch_names = insts[0].info.ch_names diff_chns = [] From dc8a71a174e91c55c9d591f3d4c4f2af3b8abedb Mon Sep 17 00:00:00 2001 From: nordme Date: Wed, 27 Sep 2023 10:23:19 -0700 Subject: [PATCH 19/38] reconcile with Ana --- mne/preprocessing/tests/test_unify_bads.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/preprocessing/tests/test_unify_bads.py b/mne/preprocessing/tests/test_unify_bads.py index 6ca9d60fa72..d5fe4f2b849 100644 --- a/mne/preprocessing/tests/test_unify_bads.py +++ b/mne/preprocessing/tests/test_unify_bads.py @@ -18,15 +18,15 @@ def test_instance_support(instance, request, evoked): def test_error_raising(raw, epochs): """Tests input checking.""" - with pytest.raises(IndexError, match=r"list index"): + with pytest.raises(ValueError, match=r"empty list"): unify_bad_channels([]) with pytest.raises(TypeError, match=r"must be an instance of"): unify_bad_channels(["bad_instance"]) with pytest.raises(ValueError, match=r"same type"): unify_bad_channels([raw, epochs]) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): raw_alt = raw.copy() - raw_alt.info["ch_names"] = raw.info["ch_names"][0] + raw_alt.drop_channels(raw.info["ch_names"][0]) unify_bad_channels([raw, raw_alt]) From 7e38dd59b452c24aaf2c6b0a279f1e204900d438 Mon Sep 17 00:00:00 2001 From: Ana Radanovic Date: Wed, 27 Sep 2023 13:50:11 -0400 Subject: [PATCH 20/38] changing ch_name check --- mne/preprocessing/bads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 0e3bc46fd55..27b0679f82c 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -105,7 +105,7 @@ def unify_bad_channels(insts): diff_chns = [] for inst in insts[1:]: if inst.info.ch_names != ch_names: - dif = set(inst.info.ch_names).difference(set(ch_names)) + dif = set(inst.info.ch_names) ^ (set(ch_names)) diff_chns.extend(list(dif)) if len(diff_chns) > 0: From 2d4528ed87d52ae984c4e5249ef7840981c55eb3 Mon Sep 17 00:00:00 2001 From: Ana Radanovic <79697247+anaradanovic@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:35:22 -0400 Subject: [PATCH 21/38] Update mne/preprocessing/bads.py Co-authored-by: Daniel McCloy --- mne/preprocessing/bads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 27b0679f82c..4a71b4019a1 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -60,7 +60,7 @@ def _find_outliers(X, threshold=3.0, max_iter=2, tail=0): def unify_bad_channels(insts): - """Unify bad channels across list of instances. + """Unify bad channels across a list of instances, using the union. This function looks across the list of instances to gather a list of "bad" channels. Every instance's info["bads"] is set with the same "bad" From a525f33ec943a3946c12214db392b7f33e29b774 Mon Sep 17 00:00:00 2001 From: Ana Radanovic <79697247+anaradanovic@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:35:49 -0400 Subject: [PATCH 22/38] Update mne/preprocessing/bads.py Co-authored-by: Daniel McCloy --- mne/preprocessing/bads.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 4a71b4019a1..8f7d42655ab 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -62,10 +62,6 @@ def _find_outliers(X, threshold=3.0, max_iter=2, tail=0): def unify_bad_channels(insts): """Unify bad channels across a list of instances, using the union. - This function looks across the list of instances to gather a list of - "bad" channels. Every instance's info["bads"] is set with the same "bad" - channels. - Parameters ---------- insts : list From ccba9359494d762f96445ff977d7a1fc2880c61e Mon Sep 17 00:00:00 2001 From: Ana Radanovic <79697247+anaradanovic@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:36:00 -0400 Subject: [PATCH 23/38] Update mne/preprocessing/bads.py Co-authored-by: Daniel McCloy --- mne/preprocessing/bads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 8f7d42655ab..20a5b7f9625 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -67,7 +67,7 @@ def unify_bad_channels(insts): insts : list List of instances (:class:`~mne.io.Raw`, :class:`~mne.Epochs`, :class:`~mne.Evoked`, :class:`~mne.time_frequency.Spectrum`, - :class:`~mne.time_frequency.EpochSpectrum`) to unify bad channels. + :class:`~mne.time_frequency.EpochSpectrum`) across which to unify bad channels. Returns ------- From e0050141deaa8a7293c4312992606f05b0d5d8cf Mon Sep 17 00:00:00 2001 From: Ana Radanovic <79697247+anaradanovic@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:36:15 -0400 Subject: [PATCH 24/38] Update mne/preprocessing/bads.py Co-authored-by: Daniel McCloy --- mne/preprocessing/bads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 20a5b7f9625..7ecd52fe04d 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -71,7 +71,7 @@ def unify_bad_channels(insts): Returns ------- - new_insts : list + insts : list List of instances with bad channels unified across instances. Notes From adfb46b84649c786ae2f20074ac53235856091c4 Mon Sep 17 00:00:00 2001 From: Ana Radanovic <79697247+anaradanovic@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:36:29 -0400 Subject: [PATCH 25/38] Update mne/preprocessing/bads.py Co-authored-by: Daniel McCloy --- mne/preprocessing/bads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 7ecd52fe04d..85aa455dd63 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -97,7 +97,7 @@ def unify_bad_channels(insts): raise ValueError(f"insts must be a *list* of mne objects, got {type(insts)}") # check that all channels have the same name and same number - ch_names = insts[0].info.ch_names + ch_names = insts[0].ch_names diff_chns = [] for inst in insts[1:]: if inst.info.ch_names != ch_names: From cfa1c740f258395845b731630d58ebb615d655af Mon Sep 17 00:00:00 2001 From: Ana Radanovic <79697247+anaradanovic@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:36:37 -0400 Subject: [PATCH 26/38] Update mne/preprocessing/bads.py Co-authored-by: Daniel McCloy --- mne/preprocessing/bads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 85aa455dd63..e7f24086bd6 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -100,7 +100,7 @@ def unify_bad_channels(insts): ch_names = insts[0].ch_names diff_chns = [] for inst in insts[1:]: - if inst.info.ch_names != ch_names: + if inst.ch_names != ch_names: dif = set(inst.info.ch_names) ^ (set(ch_names)) diff_chns.extend(list(dif)) From 3085fa0576bdd55ddab795ad81927edfedf736cb Mon Sep 17 00:00:00 2001 From: Ana Radanovic <79697247+anaradanovic@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:36:59 -0400 Subject: [PATCH 27/38] Update mne/preprocessing/bads.py Co-authored-by: Daniel McCloy --- mne/preprocessing/bads.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index e7f24086bd6..00df9ba2988 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -113,7 +113,6 @@ def unify_bad_channels(insts): # then iterate through the insts to gather bads all_bads = dict() for inst in insts: - # using dictionary method to remove duplicates & preserve order all_bads.update(dict.fromkeys(inst.info["bads"])) all_bads = list(all_bads) From 12b8e8480319a6a5bf8064f19371bfdaa6503eba Mon Sep 17 00:00:00 2001 From: Ana Radanovic <79697247+anaradanovic@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:37:09 -0400 Subject: [PATCH 28/38] Update mne/preprocessing/bads.py Co-authored-by: Daniel McCloy --- mne/preprocessing/bads.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 00df9ba2988..ea6ba1fc9a7 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -110,7 +110,8 @@ def unify_bad_channels(insts): f" Mismatch channels are {diff_chns}" ) - # then iterate through the insts to gather bads + # collect bads as dict keys so that insertion order is preserved + # then later cast to list. all_bads = dict() for inst in insts: all_bads.update(dict.fromkeys(inst.info["bads"])) From 9cb0cf3a5215b1e087c3e365ca41f4674c2f6496 Mon Sep 17 00:00:00 2001 From: Ana Radanovic Date: Wed, 27 Sep 2023 16:30:05 -0400 Subject: [PATCH 29/38] further changes on unify_bad_channels --- doc/preprocessing.rst | 1 + mne/channels/__init__.py | 1 + mne/channels/channels.py | 86 +++++++++++++++++++++++++++++++++++ mne/preprocessing/__init__.py | 1 - mne/preprocessing/bads.py | 79 +------------------------------- 5 files changed, 89 insertions(+), 79 deletions(-) diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 7c4e430780f..cbfda1ac49b 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -57,6 +57,7 @@ Projections: get_builtin_ch_adjacencies read_ch_adjacency equalize_channels + unify_bad_channels rename_channels generate_2d_layout make_1020_channel_selections diff --git a/mne/channels/__init__.py b/mne/channels/__init__.py index 13b002e5a59..3591d7aeeb4 100644 --- a/mne/channels/__init__.py +++ b/mne/channels/__init__.py @@ -22,6 +22,7 @@ "_EEG_SELECTIONS", "_divide_to_regions", "get_builtin_ch_adjacencies", + "unify_bad_channels", ], "layout": [ "Layout", diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 0b022569ef5..7967b7633c7 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -5,6 +5,8 @@ # Andrew Dykstra # Teon Brooks # Daniel McCloy +# Ana Radanovic +# Erica Peterson # # License: BSD-3-Clause @@ -206,6 +208,90 @@ def equalize_channels(instances, copy=True, verbose=None): return equalized_instances +def unify_bad_channels(insts): + """Unify bad channels across a list of instances, using the union. + + Parameters + ---------- + insts : list + List of instances (:class:`~mne.io.Raw`, :class:`~mne.Epochs`, + :class:`~mne.Evoked`, :class:`~mne.time_frequency.Spectrum`, + :class:`~mne.time_frequency.EpochSpectrum`) across which to unify bad channels. + + Returns + ------- + insts : list + List of instances with bad channels unified across instances. + + See Also + -------- + mne.channels.equalize_channels + mne.channels.rename_channels + mne.channels.combine_channels + + Notes + ----- + This function modifies the instances in-place. + + .. versionadded:: 1.6 + """ + from ..utils import _validate_type + + from ..io import BaseRaw + from ..epochs import Epochs + from ..evoked import Evoked + from ..time_frequency.spectrum import BaseSpectrum + + # check that input is a list + _validate_type(insts, (list, tuple), "instance type") + + # check input is not an empty list + if len(insts) == 0: + raise ValueError("Be sure insts is not empty list") + + # first check that each object is mne object + inst_type = type(insts[0]) + valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) + for inst in insts: + _validate_type(inst, valid_types, "instance type") + if type(inst) != inst_type: + raise ValueError("All insts must be the same type") + + # check that all channels have the same name and same number + ch_names = insts[0].ch_names + diff_chns = [] + for inst in insts[1:]: + if inst.ch_names != ch_names: + dif = set(inst.ch_names) ^ (set(ch_names)) + diff_chns.extend(list(dif)) + if ch_names.sort() == inst.ch_names.sort(): + raise ValueError( + "Channel names are sorted differently across" + "instances. Please use" + "mne.channels.equalize_channels." + ) + + if len(diff_chns) > 0: + raise ValueError( + "Channel names are not consistent across instances. Be sure" + "consistent naming is used across channels. Mismatched channels" + f"are {diff_chns}" + ) + + # collect bads as dict keys so that insertion order is preserved + # then later cast to list. + all_bads = dict() + for inst in insts: + all_bads.update(dict.fromkeys(inst.info["bads"])) + all_bads = list(all_bads) + + # apply bads set to all instances + for inst in insts: + inst.info["bads"] = all_bads + + return insts + + class ReferenceMixin(MontageMixin): """Mixin class for Raw, Evoked, Epochs.""" diff --git a/mne/preprocessing/__init__.py b/mne/preprocessing/__init__.py index 6d818ac60b0..8358f006e09 100644 --- a/mne/preprocessing/__init__.py +++ b/mne/preprocessing/__init__.py @@ -56,6 +56,5 @@ "interpolate": ["equalize_bads", "interpolate_bridged_electrodes"], "_css": ["cortical_signal_suppression"], "hfc": ["compute_proj_hfc"], - "bads": ["unify_bad_channels"], }, ) diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index ea6ba1fc9a7..8b7e23a3484 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -1,20 +1,11 @@ # Authors: Denis Engemann -# Ana Radanovic -# Erica Peterson +# # License: BSD-3-Clause import numpy as np from scipy.stats import zscore -from ..utils import _validate_type - -from ..io import BaseRaw -from ..epochs import Epochs -from ..evoked import Evoked -from ..time_frequency.spectrum import BaseSpectrum - - def _find_outliers(X, threshold=3.0, max_iter=2, tail=0): """Find outliers based on iterated Z-scoring. @@ -57,71 +48,3 @@ def _find_outliers(X, threshold=3.0, max_iter=2, tail=0): bad_idx = np.where(my_mask)[0] return bad_idx - - -def unify_bad_channels(insts): - """Unify bad channels across a list of instances, using the union. - - Parameters - ---------- - insts : list - List of instances (:class:`~mne.io.Raw`, :class:`~mne.Epochs`, - :class:`~mne.Evoked`, :class:`~mne.time_frequency.Spectrum`, - :class:`~mne.time_frequency.EpochSpectrum`) across which to unify bad channels. - - Returns - ------- - insts : list - List of instances with bad channels unified across instances. - - Notes - ----- - This function operates in-place. - - .. versionadded:: 1.6 - """ - # check input is not an empty list - if len(insts) == 0: - raise ValueError("Be sure insts is not empty list") - - # first check that each object is mne object - inst_type = type(insts[0]) - valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) - for inst in insts: - _validate_type(inst, valid_types, "instance type") - if type(inst) != inst_type: - raise ValueError("All insts must be the same type") - - # check that input is a list - if not isinstance(insts, list): - raise ValueError(f"insts must be a *list* of mne objects, got {type(insts)}") - - # check that all channels have the same name and same number - ch_names = insts[0].ch_names - diff_chns = [] - for inst in insts[1:]: - if inst.ch_names != ch_names: - dif = set(inst.info.ch_names) ^ (set(ch_names)) - diff_chns.extend(list(dif)) - - if len(diff_chns) > 0: - raise ValueError( - "Channel names are not consistent across instances." - f" Mismatch channels are {diff_chns}" - ) - - # collect bads as dict keys so that insertion order is preserved - # then later cast to list. - all_bads = dict() - for inst in insts: - all_bads.update(dict.fromkeys(inst.info["bads"])) - all_bads = list(all_bads) - - # apply bads set to all instances - new_insts = [] - - for inst in insts: - inst.info["bads"] = all_bads - new_insts.append(inst) - - return new_insts From 52c4d21d93b1bb4df41f2c247bd9e6183a7b65e3 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Wed, 27 Sep 2023 16:15:47 -0500 Subject: [PATCH 30/38] Apply suggestions from code review --- mne/channels/channels.py | 48 ++++++++++++++++----------------------- mne/preprocessing/bads.py | 1 - 2 files changed, 19 insertions(+), 30 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 7967b7633c7..18291cdece4 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -209,8 +209,11 @@ def equalize_channels(instances, copy=True, verbose=None): def unify_bad_channels(insts): - """Unify bad channels across a list of instances, using the union. + """Unify bad channels across a list of instances. + All instances must be of the same type and have matching channel names and channel + order. The ``.info["bads"]`` of each instance will be set to the union of + ``.info["bads"]`` across all instances. Parameters ---------- insts : list @@ -235,21 +238,17 @@ def unify_bad_channels(insts): .. versionadded:: 1.6 """ - from ..utils import _validate_type - from ..io import BaseRaw from ..epochs import Epochs from ..evoked import Evoked from ..time_frequency.spectrum import BaseSpectrum - # check that input is a list + # ensure input is list-like _validate_type(insts, (list, tuple), "instance type") - - # check input is not an empty list + # ensure non-empty if len(insts) == 0: raise ValueError("Be sure insts is not empty list") - - # first check that each object is mne object + # ensure all insts are MNE objects, and all the same type inst_type = type(insts[0]) valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) for inst in insts: @@ -257,35 +256,26 @@ def unify_bad_channels(insts): if type(inst) != inst_type: raise ValueError("All insts must be the same type") - # check that all channels have the same name and same number + # ensure all insts have the same channels and channel order ch_names = insts[0].ch_names - diff_chns = [] for inst in insts[1:]: - if inst.ch_names != ch_names: - dif = set(inst.ch_names) ^ (set(ch_names)) - diff_chns.extend(list(dif)) - if ch_names.sort() == inst.ch_names.sort(): - raise ValueError( - "Channel names are sorted differently across" - "instances. Please use" - "mne.channels.equalize_channels." - ) - - if len(diff_chns) > 0: - raise ValueError( - "Channel names are not consistent across instances. Be sure" - "consistent naming is used across channels. Mismatched channels" - f"are {diff_chns}" - ) + dif = set(inst.ch_names) ^ set(ch_names) + if len(dif): + # TODO raise error, we know some chs need to be dropped + raise ValueError("") + elif inst.ch_names != ch_names: + raise ValueError( + "Channel names are sorted differently across instances. Please use " + "mne.channels.equalize_channels." + ) - # collect bads as dict keys so that insertion order is preserved - # then later cast to list. + # collect bads as dict keys so that insertion order is preserved, then cast to list all_bads = dict() for inst in insts: all_bads.update(dict.fromkeys(inst.info["bads"])) all_bads = list(all_bads) - # apply bads set to all instances + # update bads on all instances for inst in insts: inst.info["bads"] = all_bads diff --git a/mne/preprocessing/bads.py b/mne/preprocessing/bads.py index 8b7e23a3484..077145ff7c8 100644 --- a/mne/preprocessing/bads.py +++ b/mne/preprocessing/bads.py @@ -1,5 +1,4 @@ # Authors: Denis Engemann -# # License: BSD-3-Clause import numpy as np From 83407f6c918cb6e651760be062730e503681bda9 Mon Sep 17 00:00:00 2001 From: nordme Date: Wed, 27 Sep 2023 14:19:05 -0700 Subject: [PATCH 31/38] fixes for Ana fixes --- mne/preprocessing/tests/test_unify_bads.py | 40 ++++++++++------------ 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/mne/preprocessing/tests/test_unify_bads.py b/mne/preprocessing/tests/test_unify_bads.py index d5fe4f2b849..1788d863849 100644 --- a/mne/preprocessing/tests/test_unify_bads.py +++ b/mne/preprocessing/tests/test_unify_bads.py @@ -1,33 +1,26 @@ -import numpy as np import pytest - -from mne.preprocessing import unify_bad_channels -from mne.time_frequency.tests.test_spectrum import _get_inst - - -@pytest.mark.parametrize( - "instance", ("raw", "epochs", "evoked", "raw_spectrum", "epochs_spectrum") -) -def test_instance_support(instance, request, evoked): - """Tests support of different classes.""" - # test unify_bads function on instance (single input, no bads scenario) - inst = _get_inst(instance, request, evoked) - inst_out = unify_bad_channels([inst]) - assert inst_out == [inst] +from mne.channels import unify_bad_channels def test_error_raising(raw, epochs): """Tests input checking.""" + with pytest.raises(TypeError, match=r"must be an instance of list"): + unify_bad_channels("bad input") with pytest.raises(ValueError, match=r"empty list"): unify_bad_channels([]) with pytest.raises(TypeError, match=r"must be an instance of"): unify_bad_channels(["bad_instance"]) with pytest.raises(ValueError, match=r"same type"): unify_bad_channels([raw, epochs]) - with pytest.raises(ValueError): - raw_alt = raw.copy() - raw_alt.drop_channels(raw.info["ch_names"][0]) - unify_bad_channels([raw, raw_alt]) + with pytest.raises(ValueError, match=r"sorted differently"): + raw_alt1 = raw.copy() + raw_alt1.drop_channels(raw.info["ch_names"][-1]) + unify_bad_channels([raw, raw_alt1]) # ch diff preserving order + with pytest.raises(ValueError, match=r"not consistent"): + raw_alt2 = raw.copy() + new_order = [raw.info["ch_names"][-1]] + raw.info["ch_names"][:-1] + raw_alt2.reorder_channels(new_order) + unify_bad_channels([raw, raw_alt2]) def test_bads_compilation(raw): @@ -37,6 +30,9 @@ def test_bads_compilation(raw): bad, and b) multiple instances passed to function with varying compilation scenarios including empty bads, unique bads, and partially duplicated bads listed out-of-order. + + Only the Raw instance type is tested, since bad channel implementation is + controlled across instance types with a MixIn class. """ assert raw.info["bads"] == [] chns = raw.info["ch_names"][0:3] @@ -47,11 +43,11 @@ def test_bads_compilation(raw): three_bad.info["bads"] = chns # scenario 1: single instance passed with actual bads s_out = unify_bad_channels([one_bad]) - assert len(s_out) == 1 + assert len(s_out) == 1, len(s_out) assert s_out[0].info["bads"] == [chns[1]], (s_out[0].info["bads"], chns[1]) # scenario 2: multiple instances passed m_out = unify_bad_channels([one_bad, no_bad, three_bad]) - assert len(m_out) == 3 + assert len(m_out) == 3, len(m_out) correct_bads = [chns[1], chns[0], chns[2]] - for i in np.arange(len(m_out)): + for i in range(len(m_out)): assert m_out[i].info["bads"] == correct_bads From df45b8640d21c0978740bcf67e823e4392b039dc Mon Sep 17 00:00:00 2001 From: nordme Date: Wed, 27 Sep 2023 14:24:23 -0700 Subject: [PATCH 32/38] move tests file --- mne/{preprocessing => channels}/tests/test_unify_bads.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename mne/{preprocessing => channels}/tests/test_unify_bads.py (100%) diff --git a/mne/preprocessing/tests/test_unify_bads.py b/mne/channels/tests/test_unify_bads.py similarity index 100% rename from mne/preprocessing/tests/test_unify_bads.py rename to mne/channels/tests/test_unify_bads.py From 8b37d1eea4f0c89129c85d46ff91bfd8e1f75043 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Wed, 27 Sep 2023 16:26:35 -0500 Subject: [PATCH 33/38] Update mne/channels/channels.py --- mne/channels/channels.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 18291cdece4..5cb3d22c3e1 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -214,6 +214,7 @@ def unify_bad_channels(insts): All instances must be of the same type and have matching channel names and channel order. The ``.info["bads"]`` of each instance will be set to the union of ``.info["bads"]`` across all instances. + Parameters ---------- insts : list From b7834d5fb8aadb20a124bfc7f77fcc87300f2b55 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Wed, 27 Sep 2023 16:43:15 -0500 Subject: [PATCH 34/38] Apply more suggestions from code review --- mne/channels/channels.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 5cb3d22c3e1..194d3fa08a8 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -245,15 +245,15 @@ def unify_bad_channels(insts): from ..time_frequency.spectrum import BaseSpectrum # ensure input is list-like - _validate_type(insts, (list, tuple), "instance type") + _validate_type(insts, (list, tuple), "insts") # ensure non-empty if len(insts) == 0: - raise ValueError("Be sure insts is not empty list") + raise ValueError("insts must not be empty") # ensure all insts are MNE objects, and all the same type inst_type = type(insts[0]) valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) for inst in insts: - _validate_type(inst, valid_types, "instance type") + _validate_type(inst, valid_types, "each object in insts") if type(inst) != inst_type: raise ValueError("All insts must be the same type") @@ -262,8 +262,10 @@ def unify_bad_channels(insts): for inst in insts[1:]: dif = set(inst.ch_names) ^ set(ch_names) if len(dif): - # TODO raise error, we know some chs need to be dropped - raise ValueError("") + raise ValueError( + "Channels do not match across the objects in insts. Consider calling " + "equalize_channels before calling this function." + ) elif inst.ch_names != ch_names: raise ValueError( "Channel names are sorted differently across instances. Please use " From 0309925780066f87edc055596e689b34cec26452 Mon Sep 17 00:00:00 2001 From: nordme <38704848+nordme@users.noreply.github.com> Date: Wed, 27 Sep 2023 14:58:06 -0700 Subject: [PATCH 35/38] Apply suggestions from code review Co-authored-by: Daniel McCloy --- mne/channels/tests/test_unify_bads.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mne/channels/tests/test_unify_bads.py b/mne/channels/tests/test_unify_bads.py index 1788d863849..771a52af51c 100644 --- a/mne/channels/tests/test_unify_bads.py +++ b/mne/channels/tests/test_unify_bads.py @@ -8,17 +8,17 @@ def test_error_raising(raw, epochs): unify_bad_channels("bad input") with pytest.raises(ValueError, match=r"empty list"): unify_bad_channels([]) - with pytest.raises(TypeError, match=r"must be an instance of"): + with pytest.raises(TypeError, match=r"each object in insts must be an instance of"): unify_bad_channels(["bad_instance"]) with pytest.raises(ValueError, match=r"same type"): unify_bad_channels([raw, epochs]) - with pytest.raises(ValueError, match=r"sorted differently"): + with pytest.raises(ValueError, match=r"Channels do not match across"): raw_alt1 = raw.copy() raw_alt1.drop_channels(raw.info["ch_names"][-1]) unify_bad_channels([raw, raw_alt1]) # ch diff preserving order - with pytest.raises(ValueError, match=r"not consistent"): + with pytest.raises(ValueError, match=r"sorted differently"): raw_alt2 = raw.copy() - new_order = [raw.info["ch_names"][-1]] + raw.info["ch_names"][:-1] + new_order = [raw.ch_names[-1]] + raw.ch_names[:-1] raw_alt2.reorder_channels(new_order) unify_bad_channels([raw, raw_alt2]) @@ -35,7 +35,7 @@ def test_bads_compilation(raw): controlled across instance types with a MixIn class. """ assert raw.info["bads"] == [] - chns = raw.info["ch_names"][0:3] + chns = raw.ch_names[:3] no_bad = raw.copy() one_bad = raw.copy() one_bad.info["bads"] = [chns[1]] @@ -48,6 +48,6 @@ def test_bads_compilation(raw): # scenario 2: multiple instances passed m_out = unify_bad_channels([one_bad, no_bad, three_bad]) assert len(m_out) == 3, len(m_out) - correct_bads = [chns[1], chns[0], chns[2]] - for i in range(len(m_out)): - assert m_out[i].info["bads"] == correct_bads + expected_order = [chns[1], chns[0], chns[2]] + for inst in m_out: + assert inst.info["bads"] == expected_order From 6c610d1b341abd0ba89b24b3ab46d3503e85c2ea Mon Sep 17 00:00:00 2001 From: nordme Date: Wed, 27 Sep 2023 15:01:58 -0700 Subject: [PATCH 36/38] pytest fix --- mne/channels/tests/test_unify_bads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/channels/tests/test_unify_bads.py b/mne/channels/tests/test_unify_bads.py index 771a52af51c..ac04983802b 100644 --- a/mne/channels/tests/test_unify_bads.py +++ b/mne/channels/tests/test_unify_bads.py @@ -6,7 +6,7 @@ def test_error_raising(raw, epochs): """Tests input checking.""" with pytest.raises(TypeError, match=r"must be an instance of list"): unify_bad_channels("bad input") - with pytest.raises(ValueError, match=r"empty list"): + with pytest.raises(ValueError, match=r"insts must not be empty"): unify_bad_channels([]) with pytest.raises(TypeError, match=r"each object in insts must be an instance of"): unify_bad_channels(["bad_instance"]) From 292c3e400a6af8b20b62ee8fa971e68443144314 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 28 Sep 2023 09:46:45 -0400 Subject: [PATCH 37/38] Update mne/channels/channels.py --- mne/channels/channels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 194d3fa08a8..7c3de44fdd8 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -220,7 +220,7 @@ def unify_bad_channels(insts): insts : list List of instances (:class:`~mne.io.Raw`, :class:`~mne.Epochs`, :class:`~mne.Evoked`, :class:`~mne.time_frequency.Spectrum`, - :class:`~mne.time_frequency.EpochSpectrum`) across which to unify bad channels. + :class:`~mne.time_frequency.EpochsSpectrum`) across which to unify bad channels. Returns ------- From 7f80011f6b94e1e2e2a2ca25f41a425ae84c5e61 Mon Sep 17 00:00:00 2001 From: Ana Radanovic Date: Thu, 28 Sep 2023 12:20:19 -0400 Subject: [PATCH 38/38] tst:ping CIs