From e69039b7ab3f22d6648a56b3f51a31fd48614213 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 25 Nov 2024 09:15:37 +0100 Subject: [PATCH 01/10] test: add initial test structure, missing validation of post-hear-artifact-removed data shapes and values --- mne/preprocessing/tests/test_pca_obs.py | 74 ++++++++++++++++++------- 1 file changed, 54 insertions(+), 20 deletions(-) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 06d45062d6d..7c67db5899e 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -4,41 +4,75 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -# TODO: migrate this structure to test out function +import copy +from pathlib import Path -import pytest +import numpy as np +from scipy.signal import firls +import pytest from mne.io import read_raw_fif -from mne.preprocessing.pca_obs import pca_obs -from mne.datasets.testing import data_path, requires_testing_data -# TODO: Where are the test files we want to use located? -fname = data_path(download=False) / "eyetrack" / "test_eyelink.asc" +from mne.preprocessing import apply_pca_obs +from mne.preprocessing.ecg import find_ecg_events + +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" +raw_fname = data_path / "test_raw.fif" + + +@pytest.fixture() +def short_raw_data(): + """Create a short, picked raw instance.""" + return read_raw_fif(raw_fname, preload=True).crop(0, 7) + -@requires_testing_data @pytest.mark.parametrize( # TODO: Are there any parameters we can cycle through to # test multiple? Different fs, windows, highpass freqs, etc.? # TODO: how do we determine qrs and filter_coords? What are these? - "fs, highpass_freq, qrs, filter_coords", + ("fs", "highpass_freq", "qrs", "filter_coords"), [ (0.2, 1.0, 100, 200), (0.1, 2.0, 100, 200), ], ) -def test_heart_artifact_removal(fs, highpass_freq, qrs, filter_coords): +def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - raw = read_raw_fif(fname) - # Do something with fs and highpass as processing of the data? - ... + # get the sampling frequency of the test data and generate the filter coords as in our example + fs = short_raw.info["sfreq"] + a = [0, 0, 1, 1] + f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter + ord = round(3 * fs / 0.5) + filter_coords = firls(ord + 1, f, a) + + # extract the QRS + ecg_events, _, _ = find_ecg_events(short_raw, ch_name=None) + ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + + # copy the original raw and remove the heart artifact in-place + raw_orig = copy.deepcopy(short_raw) + apply_pca_obs( + raw=short_raw, + picks=["eeg"], + qrs=ecg_event_samples, + filter_coords=filter_coords, + ) + # raw.get_data() ? to get shapes to compare + + assert raw_orig != short_raw + + # # Do something with fs and highpass as processing of the data? + # ... + + # # call pca_obs algorithm + # result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) - # call pca_obs algorithm - result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) + # # assert results + # assert result is not None + # assert result.shape == (100, 100) + # assert result.shape == raw.shape # is this a condition we can test? + # assert result[0, 0] == 1.0 - # assert results - assert result is not None - assert result.shape == (100, 100) - assert result.shape == raw.shape # is this a condition we can test? - assert result[0, 0] == 1.0 - ... \ No newline at end of file +if __name__ == "__main__": + pytest.main(["mne/preprocessing/tests/test_pca_obs.py"]) \ No newline at end of file From 4bd3a519e36412898b9beea725c70b025d9c2138 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 25 Nov 2024 09:19:00 +0100 Subject: [PATCH 02/10] style: run pre-commit hooks on test file --- mne/preprocessing/tests/test_pca_obs.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 7c67db5899e..2edb6619c10 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -5,14 +5,13 @@ # Copyright the MNE-Python contributors. import copy -from pathlib import Path +from pathlib import Path import numpy as np +import pytest from scipy.signal import firls -import pytest from mne.io import read_raw_fif - from mne.preprocessing import apply_pca_obs from mne.preprocessing.ecg import find_ecg_events @@ -27,7 +26,7 @@ def short_raw_data(): @pytest.mark.parametrize( - # TODO: Are there any parameters we can cycle through to + # TODO: Are there any parameters we can cycle through to # test multiple? Different fs, windows, highpass freqs, etc.? # TODO: how do we determine qrs and filter_coords? What are these? ("fs", "highpass_freq", "qrs", "filter_coords"), @@ -38,17 +37,17 @@ def short_raw_data(): ) def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - - # get the sampling frequency of the test data and generate the filter coords as in our example + # get the sampling frequency of the test data and + # generate the filter coords as in our example fs = short_raw.info["sfreq"] a = [0, 0, 1, 1] f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - ord = round(3 * fs / 0.5) - filter_coords = firls(ord + 1, f, a) + ord_ = round(3 * fs / 0.5) + filter_coords = firls(ord_ + 1, f, a) # extract the QRS ecg_events, _, _ = find_ecg_events(short_raw, ch_name=None) - ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # copy the original raw and remove the heart artifact in-place raw_orig = copy.deepcopy(short_raw) @@ -59,11 +58,10 @@ def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords filter_coords=filter_coords, ) # raw.get_data() ? to get shapes to compare - + assert raw_orig != short_raw # # Do something with fs and highpass as processing of the data? - # ... # # call pca_obs algorithm # result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) @@ -71,8 +69,9 @@ def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords # # assert results # assert result is not None # assert result.shape == (100, 100) - # assert result.shape == raw.shape # is this a condition we can test? + # assert result.shape == raw.shape # is this a condition we can test? # assert result[0, 0] == 1.0 + if __name__ == "__main__": - pytest.main(["mne/preprocessing/tests/test_pca_obs.py"]) \ No newline at end of file + pytest.main(["mne/preprocessing/tests/test_pca_obs.py"]) From 07458ba8fd7ac23cf5a38696056085f3ab1c4a76 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Mon, 25 Nov 2024 09:23:49 +0100 Subject: [PATCH 03/10] docs: remove duplicated docstring --- mne/preprocessing/pca_obs.py | 59 +++++++++++++++--------------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index fe877f7beb9..2eb14a951db 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -6,20 +6,18 @@ # Copyright the MNE-Python contributors. import math -from typing import Optional import numpy as np +from scipy.interpolate import PchipInterpolator as pchip from scipy.signal import detrend, filtfilt from sklearn.decomposition import PCA -from scipy.interpolate import PchipInterpolator as pchip -from scipy.signal import detrend from mne.io.fiff.raw import Raw from mne.utils import logger, warn - # TODO: check arguments passed in, raise errors, tests + def fit_ecg_template( data: np.ndarray, pca_template: np.ndarray, @@ -29,7 +27,7 @@ def fit_ecg_template( post_range: int, mid_p: float, fitted_art: np.ndarray, - post_idx_previous_peak: Optional[int], + post_idx_previous_peak: int | None, n_samples_fit: int, ) -> tuple[np.ndarray, int]: """ @@ -39,15 +37,15 @@ def fit_ecg_template( Parameters ---------- data (ndarray): Data from the raw signal (n_channels, n_times) - pca_template (ndarray): Mean heartbeat and first N (default 4) + pca_template (ndarray): Mean heartbeat and first N (default 4) principal components of the heartbeat matrix a_peak_idx (int): Sample index of current R-peak peak_range (int): Half the median RR-interval pre_range (int): Number of samples to fit before the R-peak post_range (int): Number of samples to fit after the R-peak - mid_p (float): Sample index marking middle of the median RR interval + mid_p (float): Sample index marking middle of the median RR interval in the signal. Used to extract relevant part of PCA_template. - fitted_art (ndarray): The computed heartbeat artefact computed to + fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data post_idx_previous_peak (optional int): Sample index of previous R-peak n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. @@ -57,11 +55,10 @@ def fit_ecg_template( ------- tuple[np.ndarray, int]: the fitted artifact and the next peak index """ - # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previous_peak # Then nextpeak is returned at the end and the process repeats # select window of template - template = pca_template[mid_p - peak_range - 1: mid_p + peak_range + 1, :] + template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] # select window of data and detrend it slice = data[0, a_peak_idx[0] - peak_range : a_peak_idx[0] + peak_range + 1] @@ -72,8 +69,8 @@ def fit_ecg_template( pad_fit = np.dot(template, least_square[0]) # fit artifact - fitted_art[0, a_peak_idx[0] - pre_range - 1: a_peak_idx[0] + post_range] = pad_fit[ - mid_p - pre_range - 1: mid_p + post_range + fitted_art[0, a_peak_idx[0] - pre_range - 1 : a_peak_idx[0] + post_range] = pad_fit[ + mid_p - pre_range - 1 : mid_p + post_range ].T # if last peak, return @@ -81,9 +78,9 @@ def fit_ecg_template( return fitted_art, a_peak_idx[0] + post_range # interpolate time between peaks - intpol_window = np.ceil( - [post_idx_previous_peak, a_peak_idx[0] - pre_range] - ).astype(int) # interpolation window + intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx[0] - pre_range]).astype( + int + ) # interpolation window if intpol_window[0] < intpol_window[1]: # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data @@ -97,19 +94,15 @@ def fit_ecg_template( ) # points to be interpolated in pt - the gap between the endpoints of the window x_fit = np.concatenate( [ - np.arange( - intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 - ), - np.arange( - intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 - ), + np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), + np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1), ] ) # Entire range of x values in this step (taking some number of samples before and after the window) y_fit = fitted_art[0, x_fit] y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previous_peak: a_peak_idx[0] - pre_range + 1] = ( + fitted_art[0, post_idx_previous_peak : a_peak_idx[0] - pre_range + 1] = ( y_interpol ) @@ -117,15 +110,15 @@ def fit_ecg_template( def apply_pca_obs( - raw: Raw, - picks: list[str], - qrs: np.ndarray, + raw: Raw, + picks: list[str], + qrs: np.ndarray, filter_coords: np.ndarray, n_components: int = 4, - n_jobs: Optional[int] = None, + n_jobs: int | None = None, ) -> None: """ - Main convenience function for applying the PCA-OBS algorithm + Main convenience function for applying the PCA-OBS algorithm to certain picks of a Raw object. Updates the Raw object in-place. Makes sanity checks for all inputs. @@ -144,10 +137,9 @@ def apply_pca_obs( n_jobs: int, default None Number of jobs to perform the PCA-OBS processing in parallel """ - if not qrs: raise ValueError("qrs must not be empty") - + if not filter_coords: raise ValueError("filter_coords must not be empty") @@ -161,6 +153,7 @@ def apply_pca_obs( n_components=n_components, ) + def _pca_obs( data: np.ndarray, qrs: np.ndarray, @@ -168,10 +161,8 @@ def _pca_obs( n_components: int, ) -> np.ndarray: """ - Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) - algorithm to remove the heart artefact from EEG data (shape [n_channels, n_times]) + algorithm to remove the heart artefact from EEG data (shape [n_channels, n_times]). """ - # set to baseline data = data.reshape(-1, 1) data = data.T @@ -206,7 +197,7 @@ def _pca_obs( # make sure array is long enough for PArange (if not cut off last ECG peak) while peak_idx[peak_count - 1, 0] + peak_range > len(data[0]): - peak_count = peak_count - 1 # reduce number of QRS complexes detected + peak_count = peak_count - 1 # reduce number of QRS complexes detected # Filter channel eegchan = filtfilt(filter_coords, 1, data) @@ -277,7 +268,7 @@ def _pca_obs( warn(f"Cannot fit first ECG epoch. Reason: {e}") # Deals with last edge of data - elif p == peak_count-1: + elif p == peak_count - 1: logger.info("On last section - almost there!") try: pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) From a5e68d7ec3bdf3e77dff723a118e9019a8ab20b4 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Wed, 27 Nov 2024 10:26:36 +0100 Subject: [PATCH 04/10] Removed filter_coords from within the method --- .../esg_rm_heart_artefact_pcaobs.py | 10 +-------- mne/preprocessing/pca_obs.py | 21 ++++++------------- mne/preprocessing/tests/test_pca_obs.py | 17 +++++---------- 3 files changed, 12 insertions(+), 36 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index f5e3ad3757f..cd1f2cd48c0 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -109,13 +109,6 @@ raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time)) -############################################################################### -# Create filter coefficients -a = [0, 0, 1, 1] -f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter -ord = round(3 * fs / 0.5) -fwts = firls(ord + 1, f, a) - ############################################################################### # Create evoked response about the detected R-peaks before cardiac artefact correction # Apply PCA-OBS to remove the cardiac artefact @@ -137,8 +130,7 @@ raw_concat, picks=esg_chans, n_jobs=5, - qrs=ecg_event_samples, - filter_coords=fwts + qrs=ecg_event_samples ) epochs = Epochs( diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 2eb14a951db..54e64e730b2 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -9,7 +9,7 @@ import numpy as np from scipy.interpolate import PchipInterpolator as pchip -from scipy.signal import detrend, filtfilt +from scipy.signal import detrend from sklearn.decomposition import PCA from mne.io.fiff.raw import Raw @@ -113,7 +113,6 @@ def apply_pca_obs( raw: Raw, picks: list[str], qrs: np.ndarray, - filter_coords: np.ndarray, n_components: int = 4, n_jobs: int | None = None, ) -> None: @@ -130,18 +129,15 @@ def apply_pca_obs( Channels in the Raw object to remove the heart artefact from qrs: ndarray, shape (n_peaks, 1) Array of times in (s), of detected R-peaks in ECG channel. - filter_coords: ndarray (N, ) - The numerator coefficient vector of the filter passed to scipy.signal.filtfilt n_components: int, default 4 Number of PCA components to use to form the OBS n_jobs: int, default None Number of jobs to perform the PCA-OBS processing in parallel """ - if not qrs: - raise ValueError("qrs must not be empty") - - if not filter_coords: - raise ValueError("filter_coords must not be empty") + # TODO: Causes error 'ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()' + # Removed for now + # if not qrs: + # raise ValueError("qrs must not be empty") raw.apply_function( _pca_obs, @@ -149,7 +145,6 @@ def apply_pca_obs( n_jobs=n_jobs, # args sent to PCA_OBS qrs=qrs, - filter_coords=filter_coords, n_components=n_components, ) @@ -157,7 +152,6 @@ def apply_pca_obs( def _pca_obs( data: np.ndarray, qrs: np.ndarray, - filter_coords: np.ndarray, n_components: int, ) -> np.ndarray: """ @@ -199,14 +193,11 @@ def _pca_obs( while peak_idx[peak_count - 1, 0] + peak_range > len(data[0]): peak_count = peak_count - 1 # reduce number of QRS complexes detected - # Filter channel - eegchan = filtfilt(filter_coords, 1, data) - # build PCA matrix(heart-beat-epochs x window-length) pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time] # picking out heartbeat epochs for p in range(1, peak_count): - pcamat[p - 1, :] = eegchan[ + pcamat[p - 1, :] = data[ 0, peak_idx[p, 0] - peak_range : peak_idx[p, 0] + peak_range + 1 ] diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 2edb6619c10..46367d1eb21 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -28,22 +28,16 @@ def short_raw_data(): @pytest.mark.parametrize( # TODO: Are there any parameters we can cycle through to # test multiple? Different fs, windows, highpass freqs, etc.? - # TODO: how do we determine qrs and filter_coords? What are these? - ("fs", "highpass_freq", "qrs", "filter_coords"), + # TODO: how do we determine qrs? What are these? + # QRS is marking the sample index of R-peaks in the signal + ("fs", "highpass_freq", "qrs"), [ (0.2, 1.0, 100, 200), (0.1, 2.0, 100, 200), ], ) -def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords): +def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - # get the sampling frequency of the test data and - # generate the filter coords as in our example - fs = short_raw.info["sfreq"] - a = [0, 0, 1, 1] - f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - ord_ = round(3 * fs / 0.5) - filter_coords = firls(ord_ + 1, f, a) # extract the QRS ecg_events, _, _ = find_ecg_events(short_raw, ch_name=None) @@ -55,7 +49,6 @@ def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords raw=short_raw, picks=["eeg"], qrs=ecg_event_samples, - filter_coords=filter_coords, ) # raw.get_data() ? to get shapes to compare @@ -64,7 +57,7 @@ def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs, filter_coords # # Do something with fs and highpass as processing of the data? # # call pca_obs algorithm - # result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) + # result = pca_obs(raw, qrs=qrs) # # assert results # assert result is not None From 1938573956b2db32d9f02d2d27c592bebdfdf6b0 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Wed, 27 Nov 2024 10:43:07 +0100 Subject: [PATCH 05/10] Adding info to filter to example --- examples/preprocessing/esg_rm_heart_artefact_pcaobs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index cd1f2cd48c0..52bb3af73dc 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -126,6 +126,7 @@ evoked_before = epochs.average() # Apply function - modifies the data in place +# Optionally high-pass filter the data before applying PCA-OBS to remove low frequency drifts mne.preprocessing.apply_pca_obs( raw_concat, picks=esg_chans, From 19e0802aee20e88dcc1de1502987650a2b17bf16 Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 4 Dec 2024 22:17:02 +0100 Subject: [PATCH 06/10] style: run import sorter pre-commit hook --- mne/preprocessing/__init__.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index d58c6e77d24..d0a6a1dd742 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -86,8 +86,8 @@ from .maxwell import ( maxwell_filter_prepare_emptyroom, ) from .otp import oversampled_temporal_projection +from .pca_obs import apply_pca_obs from .realign import realign_raw from .ssp import compute_proj_ecg, compute_proj_eog from .stim import fix_stim_artifact from .xdawn import Xdawn -from .pca_obs import apply_pca_obs \ No newline at end of file From 49a96d09e3c0232d509f639aa57fdcd77450951a Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 4 Dec 2024 22:17:13 +0100 Subject: [PATCH 07/10] refactor,test: migrate data shape to be 1d, add sanity checks for PCA function input, ad tests for copying original data and comparing to data modified in-place, add window size checks and remove generic try-except blocks BREAKING CHANGE --- mne/preprocessing/pca_obs.py | 242 +++++++++++------------- mne/preprocessing/tests/test_pca_obs.py | 75 ++++---- 2 files changed, 146 insertions(+), 171 deletions(-) diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 54e64e730b2..aeb5478e7eb 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -13,9 +13,6 @@ from sklearn.decomposition import PCA from mne.io.fiff.raw import Raw -from mne.utils import logger, warn - -# TODO: check arguments passed in, raise errors, tests def fit_ecg_template( @@ -31,7 +28,8 @@ def fit_ecg_template( n_samples_fit: int, ) -> tuple[np.ndarray, int]: """ - Fits the heartbeat artefact found in the data + Fits the heartbeat artefact found in the data. + Returns the fitted artefact and the index of the next peak. Parameters @@ -48,8 +46,8 @@ def fit_ecg_template( fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data post_idx_previous_peak (optional int): Sample index of previous R-peak - n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. - Helps reduce sharp edges at the end of fitted heartbeat events. + n_samples_fit (int): Sample fit for interpolation in fitted artifact + windows. Helps reduce sharp edges at end of fitted heartbeat events Returns ------- @@ -61,52 +59,55 @@ def fit_ecg_template( template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] # select window of data and detrend it - slice = data[0, a_peak_idx[0] - peak_range : a_peak_idx[0] + peak_range + 1] - detrended_data = detrend(slice.reshape(-1), type="constant") + slice_ = data[a_peak_idx - peak_range : a_peak_idx + peak_range + 1] + + detrended_data = detrend(slice_, type="constant") # maps data on template and then maps it again back to the sensor space least_square = np.linalg.lstsq(template, detrended_data, rcond=None) pad_fit = np.dot(template, least_square[0]) # fit artifact - fitted_art[0, a_peak_idx[0] - pre_range - 1 : a_peak_idx[0] + post_range] = pad_fit[ + fitted_art[a_peak_idx - pre_range - 1 : a_peak_idx + post_range] = pad_fit[ mid_p - pre_range - 1 : mid_p + post_range ].T # if last peak, return if post_idx_previous_peak is None: - return fitted_art, a_peak_idx[0] + post_range + return fitted_art, a_peak_idx + post_range # interpolate time between peaks - intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx[0] - pre_range]).astype( + intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx - pre_range]).astype( int ) # interpolation window if intpol_window[0] < intpol_window[1]: # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data - # You have x_fit which is two slices on either side of the interpolation window endpoints + # You have x_fit which is two slices on either side of the interpolation window + # endpoints # You have y_fit which is the y vals corresponding to x values above - # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate - # You have y_interpol which is values from pchip at the time points specified in x_interpol - x_interpol = np.arange( - intpol_window[0], intpol_window[1] + 1, 1 - ) # points to be interpolated in pt - the gap between the endpoints of the window + # You have x_interpol which is the time points between the two slices in x_fit + # that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in + # x_interpol + # points to be interpolated in pt - the gap between the endpoints of the window + x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) + # Entire range of x values in this step (taking some + # number of samples before and after the window) x_fit = np.concatenate( [ np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1), ] - ) # Entire range of x values in this step (taking some number of samples before and after the window) - y_fit = fitted_art[0, x_fit] + ) + y_fit = fitted_art[x_fit] y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previous_peak : a_peak_idx[0] - pre_range + 1] = ( - y_interpol - ) + # make fitted artefact in the desired range equal to the completed fit above + fitted_art[post_idx_previous_peak : a_peak_idx - pre_range + 1] = y_interpol - return fitted_art, a_peak_idx[0] + post_range + return fitted_art, a_peak_idx + post_range def apply_pca_obs( @@ -117,9 +118,9 @@ def apply_pca_obs( n_jobs: int | None = None, ) -> None: """ - Main convenience function for applying the PCA-OBS algorithm - to certain picks of a Raw object. Updates the Raw object in-place. - Makes sanity checks for all inputs. + Apply the PCA-OBS algorithm to picks of a Raw object. + + Update the Raw object in-place. Make sanity checks for all inputs. Parameters ---------- @@ -134,10 +135,13 @@ def apply_pca_obs( n_jobs: int, default None Number of jobs to perform the PCA-OBS processing in parallel """ - # TODO: Causes error 'ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()' - # Removed for now - # if not qrs: - # raise ValueError("qrs must not be empty") + # sanity checks + if len(qrs.shape) > 1: + raise ValueError("qrs must be a 1d array") + if not isinstance(n_jobs, int) or n_jobs < 1: + raise ValueError("n_jobs must be an integer greater than 0") + if not picks: + raise ValueError("picks must be a list of channel names") raw.apply_function( _pca_obs, @@ -154,35 +158,22 @@ def _pca_obs( qrs: np.ndarray, n_components: int, ) -> np.ndarray: - """ - algorithm to remove the heart artefact from EEG data (shape [n_channels, n_times]). - """ + """Algorithm to remove heart artefact from EEG data (array of length n_times).""" # set to baseline - data = data.reshape(-1, 1) - data = data.T - data = data - np.mean(data, axis=1) + data = data - np.mean(data) - # Allocate memory + # Allocate memory for artifact which will be subtracted from the data fitted_art = np.zeros(data.shape) - peakplot = np.zeros(data.shape) - - # Extract QRS events - for idx in qrs[0]: - if idx < len(peakplot[0, :]): - peakplot[0, idx] = 1 # logical indexed locations of qrs events - peak_idx = np.nonzero(peakplot)[1] # Selecting indices along columns - peak_idx = peak_idx.reshape(-1, 1) + # Extract QRS event indexes which are within out data timeframe + peak_idx = qrs[qrs < len(data)] peak_count = len(peak_idx) ################################################################## # Preparatory work - reserving memory, configure sizes, de-trend # ################################################################## - logger.info("Pulse artifact subtraction in progress...Please wait!") - # define peak range based on RR - RR = np.diff(peak_idx[:, 0]) - mRR = np.median(RR) + mRR = np.median(np.diff(peak_idx)) peak_range = round(mRR / 2) # Rounds to an integer mid_p = peak_range + 1 n_samples_fit = round( @@ -190,16 +181,15 @@ def _pca_obs( ) # sample fit for interpolation between fitted artifact windows # make sure array is long enough for PArange (if not cut off last ECG peak) - while peak_idx[peak_count - 1, 0] + peak_range > len(data[0]): + # NOTE: Here we previously checked for the last part of the window to be big enough. + while peak_idx[peak_count - 1] + peak_range > len(data): peak_count = peak_count - 1 # reduce number of QRS complexes detected # build PCA matrix(heart-beat-epochs x window-length) pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time] # picking out heartbeat epochs for p in range(1, peak_count): - pcamat[p - 1, :] = data[ - 0, peak_idx[p, 0] - peak_range : peak_idx[p, 0] + peak_range + 1 - ] + pcamat[p - 1, :] = data[peak_idx[p] - peak_range : peak_idx[p] + peak_range + 1] # detrending matrix(twice) pcamat = detrend( @@ -218,7 +208,7 @@ def _pca_obs( pca.fit(dpcamat) factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) - # define selected number of components using profile likelihood + # define selected number of components using profile likelihood ##################################### # Make template of the ECG artefact # @@ -231,95 +221,87 @@ def _pca_obs( ################ window_start_idx = [] window_end_idx = [] + post_idx_nextPeak = None + for p in range(peak_count): + # if the current peak doesn't have enough data in the + # start of the peak_range, skip fitting the artifact + if peak_idx[p] - peak_range < 0: + continue + # Deals with start portion of data if p == 0: pre_range = peak_range post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) if post_range > peak_range: post_range = peak_range - try: - post_idx_nextPeak = None - fitted_art, post_idx_nextPeak = fit_ecg_template( - data=data, - pca_template=pca_template, - a_peak_idx=peak_idx[p], - peak_range=peak_range, - pre_range=pre_range, - post_range=post_range, - mid_p=mid_p, - fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, - n_samples_fit=n_samples_fit, - ) - # Appending to list instead of using counter - window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - warn(f"Cannot fit first ECG epoch. Reason: {e}") + + fitted_art, post_idx_nextPeak = fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + # Appending to list instead of using counter + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) # Deals with last edge of data elif p == peak_count - 1: - logger.info("On last section - almost there!") - try: - pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) - post_range = peak_range - if pre_range > peak_range: - pre_range = peak_range - fitted_art, _ = fit_ecg_template( - data=data, - pca_template=pca_template, - a_peak_idx=peak_idx[p], - peak_range=peak_range, - pre_range=pre_range, - post_range=post_range, - mid_p=mid_p, - fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, - n_samples_fit=n_samples_fit, - ) - window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - warn(f"Cannot fit last ECG epoch. Reason: {e}") + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = peak_range + if pre_range > peak_range: + pre_range = peak_range + fitted_art, _ = fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) # Deals with middle portion of data else: - try: - # ---------------- Processing of central data - -------------------- - # cycle through peak artifacts identified by peakplot - pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) - post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) - if pre_range >= peak_range: - pre_range = peak_range - if post_range > peak_range: - post_range = peak_range - - a_template = pca_template[ - mid_p - peak_range - 1 : mid_p + peak_range + 1, : - ] - fitted_art, post_idx_nextPeak = fit_ecg_template( - data=data, - pca_template=a_template, - a_peak_idx=peak_idx[p], - peak_range=peak_range, - pre_range=pre_range, - post_range=post_range, - mid_p=mid_p, - fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, - n_samples_fit=n_samples_fit, - ) - window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - warn(f"Cannot fit middle section of data. Reason: {e}") + # ---------------- Processing of central data - -------------------- + # cycle through peak artifacts identified by peakplot + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if pre_range >= peak_range: + pre_range = peak_range + if post_range > peak_range: + post_range = peak_range - # Actually subtract the artefact, return needs to be the same shape as input data - data = data.reshape(-1) - fitted_art = fitted_art.reshape(-1) + a_template = pca_template[ + mid_p - peak_range - 1 : mid_p + peak_range + 1, : + ] + fitted_art, post_idx_nextPeak = fit_ecg_template( + data=data, + pca_template=a_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + # Actually subtract the artefact, return needs to be the same shape as input data data -= fitted_art - data = data.T.reshape(-1) - return data diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 46367d1eb21..82b1efa6ebf 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -4,16 +4,15 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -import copy from pathlib import Path import numpy as np +import pandas as pd import pytest -from scipy.signal import firls from mne.io import read_raw_fif +from mne.io.fiff.raw import Raw from mne.preprocessing import apply_pca_obs -from mne.preprocessing.ecg import find_ecg_events data_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_path / "test_raw.fif" @@ -22,49 +21,43 @@ @pytest.fixture() def short_raw_data(): """Create a short, picked raw instance.""" - return read_raw_fif(raw_fname, preload=True).crop(0, 7) + return read_raw_fif(raw_fname, preload=True) -@pytest.mark.parametrize( - # TODO: Are there any parameters we can cycle through to - # test multiple? Different fs, windows, highpass freqs, etc.? - # TODO: how do we determine qrs? What are these? - # QRS is marking the sample index of R-peaks in the signal - ("fs", "highpass_freq", "qrs"), - [ - (0.2, 1.0, 100, 200), - (0.1, 2.0, 100, 200), - ], -) -def test_heart_artifact_removal(short_raw, fs, highpass_freq, qrs): +def test_heart_artifact_removal(short_raw_data: Raw): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + # fake some random qrs events + ecg_event_samples = np.arange(0, len(short_raw_data.times), 1400) + 1430 - # extract the QRS - ecg_events, _, _ = find_ecg_events(short_raw, ch_name=None) - ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) + # copy the original raw. heart artifact is removed in-place + orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True) - # copy the original raw and remove the heart artifact in-place - raw_orig = copy.deepcopy(short_raw) - apply_pca_obs( - raw=short_raw, - picks=["eeg"], - qrs=ecg_event_samples, - ) - # raw.get_data() ? to get shapes to compare - - assert raw_orig != short_raw - - # # Do something with fs and highpass as processing of the data? + # perform heart artifact removal + apply_pca_obs(raw=short_raw_data, picks=["eeg"], qrs=ecg_event_samples, n_jobs=1) - # # call pca_obs algorithm - # result = pca_obs(raw, qrs=qrs) - - # # assert results - # assert result is not None - # assert result.shape == (100, 100) - # assert result.shape == raw.shape # is this a condition we can test? - # assert result[0, 0] == 1.0 + # compare processed df to original df + removed_heart_artifact_df: pd.DataFrame = short_raw_data.to_data_frame() + # ensure all column names remain the same + pd.testing.assert_index_equal( + orig_df.columns, + removed_heart_artifact_df.columns, + ) -if __name__ == "__main__": - pytest.main(["mne/preprocessing/tests/test_pca_obs.py"]) + # ensure every column starting with EEG has been altered + altered_cols = [c for c in orig_df.columns if c.startswith("EEG")] + for col in altered_cols: + with pytest.raises( + AssertionError + ): # make sure that error is raised when we check equal + pd.testing.assert_series_equal( + orig_df[col], + removed_heart_artifact_df[col], + ) + + # ensure every column not starting with EEG has not been altered + unaltered_cols = [c for c in orig_df.columns if not c.startswith("EEG")] + pd.testing.assert_frame_equal( + orig_df[unaltered_cols], + removed_heart_artifact_df[unaltered_cols], + ) From 925b2931235506cfe1fdad1127280aa8fa518719 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 6 Dec 2024 16:23:00 +0100 Subject: [PATCH 08/10] example:Reshape ecg_events before pca_obs, tests:ecg_events is in samples for algorith --- examples/preprocessing/esg_rm_heart_artefact_pcaobs.py | 2 +- mne/preprocessing/pca_obs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 52bb3af73dc..773cb7410ed 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -131,7 +131,7 @@ raw_concat, picks=esg_chans, n_jobs=5, - qrs=ecg_event_samples + qrs=ecg_event_samples.reshape(-1) ) epochs = Epochs( diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index aeb5478e7eb..05320cf22ae 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -129,7 +129,7 @@ def apply_pca_obs( picks: list[str] Channels in the Raw object to remove the heart artefact from qrs: ndarray, shape (n_peaks, 1) - Array of times in (s), of detected R-peaks in ECG channel. + Array of times in (sample indices), of detected R-peaks in ECG channel. n_components: int, default 4 Number of PCA components to use to form the OBS n_jobs: int, default None From a3d11237a6a7c56e9d3275f5ca75fc8b27e70446 Mon Sep 17 00:00:00 2001 From: Emma Bailey Date: Fri, 6 Dec 2024 21:50:08 +0100 Subject: [PATCH 09/10] example: update channel selection --- examples/preprocessing/esg_rm_heart_artefact_pcaobs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 773cb7410ed..6a66a2057cb 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -56,7 +56,7 @@ esg_chans = ["S35", "S24", "S36", "Iz", "S17", "S15", "S32", "S22", "S19", "S26", "S28", "S9", "S13", "S11", "S7", "SC1", "S4", "S18", "S8", "S31", "SC6", "S12", "S16", "S5", "S30", "S20", "S34", "S21", "S25", "L1", "S29", "S14", "S33", - "S3", "L4", "S6", "S23", 'ECG'] + "S3", "L4", "S6", "S23"] # Sampling rate fs = 1000 @@ -79,8 +79,8 @@ for count, block_file in enumerate(block_files): raw = read_raw_eeglab(block_file, eog=(), preload=True, uint16_codec=None, verbose=None) - # Isolate the ESG channels only - raw.pick(esg_chans) + # Isolate the ESG channels (including ECG for R-peak detection) + raw.pick(esg_chans + ['ECG']) # Find trigger timings to remove the stimulation artefact events, event_dict = events_from_annotations(raw) From 2ae84b0f5a69a10be243ead7a54c8cad4431161e Mon Sep 17 00:00:00 2001 From: Steinn Magnusson Date: Wed, 18 Dec 2024 14:29:28 +0100 Subject: [PATCH 10/10] refactor,test: change public qrs kwarg to be more clear about being indices, add sanity checks for input values, add negative-test which verifies proper exceptions when bad data is passed to function --- .../esg_rm_heart_artefact_pcaobs.py | 124 ++++++++++++------ mne/preprocessing/pca_obs.py | 26 ++-- mne/preprocessing/tests/test_pca_obs.py | 38 +++++- 3 files changed, 137 insertions(+), 51 deletions(-) diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index 6a66a2057cb..b75552f657a 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -1,9 +1,9 @@ """ .. _ex-pcaobs: -============================================================================================== -Principal Component Analysis - Optimal Basis Sets (PCA-OBS) for removal of cardiac artefact -============================================================================================== +===================================================================================== +Principal Component Analysis - Optimal Basis Sets (PCA-OBS) removing cardiac artefact +===================================================================================== This script shows an example of how to use an adaptation of PCA-OBS :footcite:`NiazyEtAl2005`. PCA-OBS was originally designed to remove @@ -24,13 +24,9 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from matplotlib import pyplot as plt -import mne -from mne.preprocessing import find_ecg_events, fix_stim_artifact -from mne.io import read_raw_eeglab -from scipy.signal import firls +import glob + import numpy as np -from mne import Epochs, events_from_annotations, concatenate_raws ############################################################################### # Download sample subject data from OpenNeuro if you haven't already @@ -38,25 +34,67 @@ # median nerve stimulation of the left wrist # Set the target directory to your desired location import openneuro as on -import glob +from matplotlib import pyplot as plt + +import mne +from mne import Epochs, concatenate_raws, events_from_annotations +from mne.io import read_raw_eeglab +from mne.preprocessing import find_ecg_events, fix_stim_artifact # add the path where you want the OpenNeuro data downloaded. Files total around 8 GB # target_dir = "/home/steinnhm/personal/mne-data" -target_dir = '/data/pt_02569/test_data' +target_dir = "/data/pt_02569/test_data" -file_list = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +file_list = glob.glob(target_dir + "/sub-001/eeg/*median*.set") if file_list: - print('Data is already downloaded') + print("Data is already downloaded") else: - on.download(dataset='ds004388', target_dir=target_dir, include='sub-001/*median*_eeg*') + on.download( + dataset="ds004388", target_dir=target_dir, include="sub-001/*median*_eeg*" + ) ############################################################################### # Define the esg channels (arranged in two patches over the neck and lower back) # Also include the ECG channel for artefact correction -esg_chans = ["S35", "S24", "S36", "Iz", "S17", "S15", "S32", "S22", "S19", "S26", "S28", - "S9", "S13", "S11", "S7", "SC1", "S4", "S18", "S8", "S31", "SC6", "S12", - "S16", "S5", "S30", "S20", "S34", "S21", "S25", "L1", "S29", "S14", "S33", - "S3", "L4", "S6", "S23"] +esg_chans = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "L4", + "S6", + "S23", +] # Sampling rate fs = 1000 @@ -73,21 +111,30 @@ # Read in each of the four blocks and concatenate the raw structures after performing # some minimal preprocessing including removing the stimulation artefact, downsampling # and filtering -block_files = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +block_files = glob.glob(target_dir + "/sub-001/eeg/*median*.set") block_files = sorted(block_files) for count, block_file in enumerate(block_files): - raw = read_raw_eeglab(block_file, eog=(), preload=True, uint16_codec=None, verbose=None) + raw = read_raw_eeglab( + block_file, eog=(), preload=True, uint16_codec=None, verbose=None + ) # Isolate the ESG channels (including ECG for R-peak detection) - raw.pick(esg_chans + ['ECG']) + raw.pick(esg_chans + ["ECG"]) # Find trigger timings to remove the stimulation artefact events, event_dict = events_from_annotations(raw) - trigger_name = 'Median - Stimulation' - - fix_stim_artifact(raw, events=events, event_id=event_dict[trigger_name], tmin=tstart_esg, tmax=tmax_esg, mode='linear', - stim_channel=None) + trigger_name = "Median - Stimulation" + + fix_stim_artifact( + raw, + events=events, + event_id=event_dict[trigger_name], + tmin=tstart_esg, + tmax=tmax_esg, + mode="linear", + stim_channel=None, + ) # Downsample the data raw.resample(fs) @@ -101,13 +148,19 @@ ############################################################################### # Find ECG events and add to the raw structure as event annotations ecg_events, ch_ecg, average_pulse = find_ecg_events(raw_concat, ch_name="ECG") -ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Samples only +ecg_event_samples = np.asarray( + [[ecg_event[0] for ecg_event in ecg_events]] +) # Samples only -qrs_event_time = [x / fs for x in ecg_event_samples.reshape(-1)] # Divide by sampling rate to make times +qrs_event_time = [ + x / fs for x in ecg_event_samples.reshape(-1) +] # Divide by sampling rate to make times duration = np.repeat(0.0, len(ecg_event_samples)) -description = ['qrs'] * len(ecg_event_samples) +description = ["qrs"] * len(ecg_event_samples) -raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time)) +raw_concat.annotations.append( + qrs_event_time, duration, description, ch_names=[esg_chans] * len(qrs_event_time) +) ############################################################################### # Create evoked response about the detected R-peaks before cardiac artefact correction @@ -125,13 +178,10 @@ ) evoked_before = epochs.average() -# Apply function - modifies the data in place -# Optionally high-pass filter the data before applying PCA-OBS to remove low frequency drifts +# Apply function - modifies the data in place. Optionally high-pass filter +# the data before applying PCA-OBS to remove low frequency drifts mne.preprocessing.apply_pca_obs( - raw_concat, - picks=esg_chans, - n_jobs=5, - qrs=ecg_event_samples.reshape(-1) + raw_concat, picks=esg_chans, n_jobs=5, qrs_indices=ecg_event_samples.reshape(-1) ) epochs = Epochs( @@ -150,8 +200,8 @@ axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") axes.set_ylim([-0.0005, 0.001]) -axes.set_ylabel('Amplitude (V)') -axes.set_xlabel('Time (s)') +axes.set_ylabel("Amplitude (V)") +axes.set_xlabel("Time (s)") axes.set_title("Before (black) vs. After (green)") plt.tight_layout() plt.show() diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index 05320cf22ae..0c21dae8158 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -13,6 +13,7 @@ from sklearn.decomposition import PCA from mne.io.fiff.raw import Raw +from mne.utils import logger def fit_ecg_template( @@ -113,7 +114,7 @@ def fit_ecg_template( def apply_pca_obs( raw: Raw, picks: list[str], - qrs: np.ndarray, + qrs_indices: np.ndarray, n_components: int = 4, n_jobs: int | None = None, ) -> None: @@ -128,18 +129,25 @@ def apply_pca_obs( The raw data to process picks: list[str] Channels in the Raw object to remove the heart artefact from - qrs: ndarray, shape (n_peaks, 1) - Array of times in (sample indices), of detected R-peaks in ECG channel. + qrs_indices: ndarray, shape (n_peaks, 1) + Array of indices in the Raw data of detected R-peaks in ECG channel. n_components: int, default 4 Number of PCA components to use to form the OBS n_jobs: int, default None - Number of jobs to perform the PCA-OBS processing in parallel + Number of jobs to perform the PCA-OBS processing in parallel. + Passed on to Raw.apply_function """ # sanity checks - if len(qrs.shape) > 1: - raise ValueError("qrs must be a 1d array") - if not isinstance(n_jobs, int) or n_jobs < 1: - raise ValueError("n_jobs must be an integer greater than 0") + if not isinstance(qrs_indices, np.ndarray): + raise ValueError("qrs_indices must be an array") + if len(qrs_indices.shape) > 1: + raise ValueError("qrs_indices must be a 1d array") + if qrs_indices.dtype != int: + raise ValueError("qrs_indices must be an array of integers") + if np.any(qrs_indices < 0): + raise ValueError("qrs_indices must be strictly positive integers") + if np.any(qrs_indices >= raw.n_times): + logger.warning("out of bound qrs_indices will be ignored..") if not picks: raise ValueError("picks must be a list of channel names") @@ -148,7 +156,7 @@ def apply_pca_obs( picks=picks, n_jobs=n_jobs, # args sent to PCA_OBS - qrs=qrs, + qrs=qrs_indices, n_components=n_components, ) diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 82b1efa6ebf..e2d07a8ce72 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -1,5 +1,3 @@ -"""Test the ieeg projection functions.""" - # Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. @@ -26,14 +24,17 @@ def short_raw_data(): def test_heart_artifact_removal(short_raw_data: Raw): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - # fake some random qrs events - ecg_event_samples = np.arange(0, len(short_raw_data.times), 1400) + 1430 + # fake some random qrs events in the window of the raw data + # remove first and last samples and cast to integer for indexing + ecg_event_indices = np.linspace(0, short_raw_data.n_times, 20, dtype=int)[1:-1] # copy the original raw. heart artifact is removed in-place orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True) # perform heart artifact removal - apply_pca_obs(raw=short_raw_data, picks=["eeg"], qrs=ecg_event_samples, n_jobs=1) + apply_pca_obs( + raw=short_raw_data, picks=["eeg"], qrs_indices=ecg_event_indices, n_jobs=1 + ) # compare processed df to original df removed_heart_artifact_df: pd.DataFrame = short_raw_data.to_data_frame() @@ -61,3 +62,30 @@ def test_heart_artifact_removal(short_raw_data: Raw): orig_df[unaltered_cols], removed_heart_artifact_df[unaltered_cols], ) + + +# test that various nonsensical inputs raise the proper errors +@pytest.mark.parametrize( + ("picks", "qrs", "error"), + [ + (["eeg"], np.array([[0, 1], [2, 3]]), "qrs_indices must be a 1d array"), + (["eeg"], [2, 3, 4], "qrs_indices must be an array"), + ( + ["eeg"], + np.array([None, "foo", 2]), + "qrs_indices must be an array of integers", + ), + ( + ["eeg"], + np.array([-1, 0, 3]), + "qrs_indices must be strictly positive integers", + ), + ([], np.array([1, 2, 3]), "picks must be a list of channel names"), + ], +) +def test_pca_obs_bad_input( + short_raw_data: Raw, picks: list[str], qrs: np.ndarray, error: str +): + """Test if bad input data raises the proper errors in the function sanity checks.""" + with pytest.raises(ValueError, match=error): + apply_pca_obs(raw=short_raw_data, picks=picks, qrs_indices=qrs)