diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py index f5e3ad3757f..b75552f657a 100755 --- a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -1,9 +1,9 @@ """ .. _ex-pcaobs: -============================================================================================== -Principal Component Analysis - Optimal Basis Sets (PCA-OBS) for removal of cardiac artefact -============================================================================================== +===================================================================================== +Principal Component Analysis - Optimal Basis Sets (PCA-OBS) removing cardiac artefact +===================================================================================== This script shows an example of how to use an adaptation of PCA-OBS :footcite:`NiazyEtAl2005`. PCA-OBS was originally designed to remove @@ -24,13 +24,9 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from matplotlib import pyplot as plt -import mne -from mne.preprocessing import find_ecg_events, fix_stim_artifact -from mne.io import read_raw_eeglab -from scipy.signal import firls +import glob + import numpy as np -from mne import Epochs, events_from_annotations, concatenate_raws ############################################################################### # Download sample subject data from OpenNeuro if you haven't already @@ -38,25 +34,67 @@ # median nerve stimulation of the left wrist # Set the target directory to your desired location import openneuro as on -import glob +from matplotlib import pyplot as plt + +import mne +from mne import Epochs, concatenate_raws, events_from_annotations +from mne.io import read_raw_eeglab +from mne.preprocessing import find_ecg_events, fix_stim_artifact # add the path where you want the OpenNeuro data downloaded. Files total around 8 GB # target_dir = "/home/steinnhm/personal/mne-data" -target_dir = '/data/pt_02569/test_data' +target_dir = "/data/pt_02569/test_data" -file_list = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +file_list = glob.glob(target_dir + "/sub-001/eeg/*median*.set") if file_list: - print('Data is already downloaded') + print("Data is already downloaded") else: - on.download(dataset='ds004388', target_dir=target_dir, include='sub-001/*median*_eeg*') + on.download( + dataset="ds004388", target_dir=target_dir, include="sub-001/*median*_eeg*" + ) ############################################################################### # Define the esg channels (arranged in two patches over the neck and lower back) # Also include the ECG channel for artefact correction -esg_chans = ["S35", "S24", "S36", "Iz", "S17", "S15", "S32", "S22", "S19", "S26", "S28", - "S9", "S13", "S11", "S7", "SC1", "S4", "S18", "S8", "S31", "SC6", "S12", - "S16", "S5", "S30", "S20", "S34", "S21", "S25", "L1", "S29", "S14", "S33", - "S3", "L4", "S6", "S23", 'ECG'] +esg_chans = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "L4", + "S6", + "S23", +] # Sampling rate fs = 1000 @@ -73,21 +111,30 @@ # Read in each of the four blocks and concatenate the raw structures after performing # some minimal preprocessing including removing the stimulation artefact, downsampling # and filtering -block_files = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +block_files = glob.glob(target_dir + "/sub-001/eeg/*median*.set") block_files = sorted(block_files) for count, block_file in enumerate(block_files): - raw = read_raw_eeglab(block_file, eog=(), preload=True, uint16_codec=None, verbose=None) + raw = read_raw_eeglab( + block_file, eog=(), preload=True, uint16_codec=None, verbose=None + ) - # Isolate the ESG channels only - raw.pick(esg_chans) + # Isolate the ESG channels (including ECG for R-peak detection) + raw.pick(esg_chans + ["ECG"]) # Find trigger timings to remove the stimulation artefact events, event_dict = events_from_annotations(raw) - trigger_name = 'Median - Stimulation' - - fix_stim_artifact(raw, events=events, event_id=event_dict[trigger_name], tmin=tstart_esg, tmax=tmax_esg, mode='linear', - stim_channel=None) + trigger_name = "Median - Stimulation" + + fix_stim_artifact( + raw, + events=events, + event_id=event_dict[trigger_name], + tmin=tstart_esg, + tmax=tmax_esg, + mode="linear", + stim_channel=None, + ) # Downsample the data raw.resample(fs) @@ -101,20 +148,19 @@ ############################################################################### # Find ECG events and add to the raw structure as event annotations ecg_events, ch_ecg, average_pulse = find_ecg_events(raw_concat, ch_name="ECG") -ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Samples only +ecg_event_samples = np.asarray( + [[ecg_event[0] for ecg_event in ecg_events]] +) # Samples only -qrs_event_time = [x / fs for x in ecg_event_samples.reshape(-1)] # Divide by sampling rate to make times +qrs_event_time = [ + x / fs for x in ecg_event_samples.reshape(-1) +] # Divide by sampling rate to make times duration = np.repeat(0.0, len(ecg_event_samples)) -description = ['qrs'] * len(ecg_event_samples) - -raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time)) +description = ["qrs"] * len(ecg_event_samples) -############################################################################### -# Create filter coefficients -a = [0, 0, 1, 1] -f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter -ord = round(3 * fs / 0.5) -fwts = firls(ord + 1, f, a) +raw_concat.annotations.append( + qrs_event_time, duration, description, ch_names=[esg_chans] * len(qrs_event_time) +) ############################################################################### # Create evoked response about the detected R-peaks before cardiac artefact correction @@ -132,13 +178,10 @@ ) evoked_before = epochs.average() -# Apply function - modifies the data in place +# Apply function - modifies the data in place. Optionally high-pass filter +# the data before applying PCA-OBS to remove low frequency drifts mne.preprocessing.apply_pca_obs( - raw_concat, - picks=esg_chans, - n_jobs=5, - qrs=ecg_event_samples, - filter_coords=fwts + raw_concat, picks=esg_chans, n_jobs=5, qrs_indices=ecg_event_samples.reshape(-1) ) epochs = Epochs( @@ -157,8 +200,8 @@ axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") axes.set_ylim([-0.0005, 0.001]) -axes.set_ylabel('Amplitude (V)') -axes.set_xlabel('Time (s)') +axes.set_ylabel("Amplitude (V)") +axes.set_xlabel("Time (s)") axes.set_title("Before (black) vs. After (green)") plt.tight_layout() plt.show() diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index d58c6e77d24..d0a6a1dd742 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -86,8 +86,8 @@ from .maxwell import ( maxwell_filter_prepare_emptyroom, ) from .otp import oversampled_temporal_projection +from .pca_obs import apply_pca_obs from .realign import realign_raw from .ssp import compute_proj_ecg, compute_proj_eog from .stim import fix_stim_artifact from .xdawn import Xdawn -from .pca_obs import apply_pca_obs \ No newline at end of file diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py index fe877f7beb9..0c21dae8158 100755 --- a/mne/preprocessing/pca_obs.py +++ b/mne/preprocessing/pca_obs.py @@ -6,20 +6,16 @@ # Copyright the MNE-Python contributors. import math -from typing import Optional import numpy as np -from scipy.signal import detrend, filtfilt -from sklearn.decomposition import PCA from scipy.interpolate import PchipInterpolator as pchip from scipy.signal import detrend +from sklearn.decomposition import PCA from mne.io.fiff.raw import Raw -from mne.utils import logger, warn +from mne.utils import logger -# TODO: check arguments passed in, raise errors, tests - def fit_ecg_template( data: np.ndarray, pca_template: np.ndarray, @@ -29,105 +25,103 @@ def fit_ecg_template( post_range: int, mid_p: float, fitted_art: np.ndarray, - post_idx_previous_peak: Optional[int], + post_idx_previous_peak: int | None, n_samples_fit: int, ) -> tuple[np.ndarray, int]: """ - Fits the heartbeat artefact found in the data + Fits the heartbeat artefact found in the data. + Returns the fitted artefact and the index of the next peak. Parameters ---------- data (ndarray): Data from the raw signal (n_channels, n_times) - pca_template (ndarray): Mean heartbeat and first N (default 4) + pca_template (ndarray): Mean heartbeat and first N (default 4) principal components of the heartbeat matrix a_peak_idx (int): Sample index of current R-peak peak_range (int): Half the median RR-interval pre_range (int): Number of samples to fit before the R-peak post_range (int): Number of samples to fit after the R-peak - mid_p (float): Sample index marking middle of the median RR interval + mid_p (float): Sample index marking middle of the median RR interval in the signal. Used to extract relevant part of PCA_template. - fitted_art (ndarray): The computed heartbeat artefact computed to + fitted_art (ndarray): The computed heartbeat artefact computed to remove from the data post_idx_previous_peak (optional int): Sample index of previous R-peak - n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. - Helps reduce sharp edges at the end of fitted heartbeat events. + n_samples_fit (int): Sample fit for interpolation in fitted artifact + windows. Helps reduce sharp edges at end of fitted heartbeat events Returns ------- tuple[np.ndarray, int]: the fitted artifact and the next peak index """ - # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previous_peak # Then nextpeak is returned at the end and the process repeats # select window of template - template = pca_template[mid_p - peak_range - 1: mid_p + peak_range + 1, :] + template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :] # select window of data and detrend it - slice = data[0, a_peak_idx[0] - peak_range : a_peak_idx[0] + peak_range + 1] - detrended_data = detrend(slice.reshape(-1), type="constant") + slice_ = data[a_peak_idx - peak_range : a_peak_idx + peak_range + 1] + + detrended_data = detrend(slice_, type="constant") # maps data on template and then maps it again back to the sensor space least_square = np.linalg.lstsq(template, detrended_data, rcond=None) pad_fit = np.dot(template, least_square[0]) # fit artifact - fitted_art[0, a_peak_idx[0] - pre_range - 1: a_peak_idx[0] + post_range] = pad_fit[ - mid_p - pre_range - 1: mid_p + post_range + fitted_art[a_peak_idx - pre_range - 1 : a_peak_idx + post_range] = pad_fit[ + mid_p - pre_range - 1 : mid_p + post_range ].T # if last peak, return if post_idx_previous_peak is None: - return fitted_art, a_peak_idx[0] + post_range + return fitted_art, a_peak_idx + post_range # interpolate time between peaks - intpol_window = np.ceil( - [post_idx_previous_peak, a_peak_idx[0] - pre_range] - ).astype(int) # interpolation window + intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx - pre_range]).astype( + int + ) # interpolation window if intpol_window[0] < intpol_window[1]: # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data - # You have x_fit which is two slices on either side of the interpolation window endpoints + # You have x_fit which is two slices on either side of the interpolation window + # endpoints # You have y_fit which is the y vals corresponding to x values above - # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate - # You have y_interpol which is values from pchip at the time points specified in x_interpol - x_interpol = np.arange( - intpol_window[0], intpol_window[1] + 1, 1 - ) # points to be interpolated in pt - the gap between the endpoints of the window + # You have x_interpol which is the time points between the two slices in x_fit + # that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in + # x_interpol + # points to be interpolated in pt - the gap between the endpoints of the window + x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) + # Entire range of x values in this step (taking some + # number of samples before and after the window) x_fit = np.concatenate( [ - np.arange( - intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 - ), - np.arange( - intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 - ), + np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), + np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1), ] - ) # Entire range of x values in this step (taking some number of samples before and after the window) - y_fit = fitted_art[0, x_fit] + ) + y_fit = fitted_art[x_fit] y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previous_peak: a_peak_idx[0] - pre_range + 1] = ( - y_interpol - ) + # make fitted artefact in the desired range equal to the completed fit above + fitted_art[post_idx_previous_peak : a_peak_idx - pre_range + 1] = y_interpol - return fitted_art, a_peak_idx[0] + post_range + return fitted_art, a_peak_idx + post_range def apply_pca_obs( - raw: Raw, - picks: list[str], - qrs: np.ndarray, - filter_coords: np.ndarray, + raw: Raw, + picks: list[str], + qrs_indices: np.ndarray, n_components: int = 4, - n_jobs: Optional[int] = None, + n_jobs: int | None = None, ) -> None: """ - Main convenience function for applying the PCA-OBS algorithm - to certain picks of a Raw object. Updates the Raw object in-place. - Makes sanity checks for all inputs. + Apply the PCA-OBS algorithm to picks of a Raw object. + + Update the Raw object in-place. Make sanity checks for all inputs. Parameters ---------- @@ -135,69 +129,59 @@ def apply_pca_obs( The raw data to process picks: list[str] Channels in the Raw object to remove the heart artefact from - qrs: ndarray, shape (n_peaks, 1) - Array of times in (s), of detected R-peaks in ECG channel. - filter_coords: ndarray (N, ) - The numerator coefficient vector of the filter passed to scipy.signal.filtfilt + qrs_indices: ndarray, shape (n_peaks, 1) + Array of indices in the Raw data of detected R-peaks in ECG channel. n_components: int, default 4 Number of PCA components to use to form the OBS n_jobs: int, default None - Number of jobs to perform the PCA-OBS processing in parallel + Number of jobs to perform the PCA-OBS processing in parallel. + Passed on to Raw.apply_function """ - - if not qrs: - raise ValueError("qrs must not be empty") - - if not filter_coords: - raise ValueError("filter_coords must not be empty") + # sanity checks + if not isinstance(qrs_indices, np.ndarray): + raise ValueError("qrs_indices must be an array") + if len(qrs_indices.shape) > 1: + raise ValueError("qrs_indices must be a 1d array") + if qrs_indices.dtype != int: + raise ValueError("qrs_indices must be an array of integers") + if np.any(qrs_indices < 0): + raise ValueError("qrs_indices must be strictly positive integers") + if np.any(qrs_indices >= raw.n_times): + logger.warning("out of bound qrs_indices will be ignored..") + if not picks: + raise ValueError("picks must be a list of channel names") raw.apply_function( _pca_obs, picks=picks, n_jobs=n_jobs, # args sent to PCA_OBS - qrs=qrs, - filter_coords=filter_coords, + qrs=qrs_indices, n_components=n_components, ) + def _pca_obs( data: np.ndarray, qrs: np.ndarray, - filter_coords: np.ndarray, n_components: int, ) -> np.ndarray: - """ - Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) - algorithm to remove the heart artefact from EEG data (shape [n_channels, n_times]) - """ - + """Algorithm to remove heart artefact from EEG data (array of length n_times).""" # set to baseline - data = data.reshape(-1, 1) - data = data.T - data = data - np.mean(data, axis=1) + data = data - np.mean(data) - # Allocate memory + # Allocate memory for artifact which will be subtracted from the data fitted_art = np.zeros(data.shape) - peakplot = np.zeros(data.shape) - # Extract QRS events - for idx in qrs[0]: - if idx < len(peakplot[0, :]): - peakplot[0, idx] = 1 # logical indexed locations of qrs events - - peak_idx = np.nonzero(peakplot)[1] # Selecting indices along columns - peak_idx = peak_idx.reshape(-1, 1) + # Extract QRS event indexes which are within out data timeframe + peak_idx = qrs[qrs < len(data)] peak_count = len(peak_idx) ################################################################## # Preparatory work - reserving memory, configure sizes, de-trend # ################################################################## - logger.info("Pulse artifact subtraction in progress...Please wait!") - # define peak range based on RR - RR = np.diff(peak_idx[:, 0]) - mRR = np.median(RR) + mRR = np.median(np.diff(peak_idx)) peak_range = round(mRR / 2) # Rounds to an integer mid_p = peak_range + 1 n_samples_fit = round( @@ -205,19 +189,15 @@ def _pca_obs( ) # sample fit for interpolation between fitted artifact windows # make sure array is long enough for PArange (if not cut off last ECG peak) - while peak_idx[peak_count - 1, 0] + peak_range > len(data[0]): - peak_count = peak_count - 1 # reduce number of QRS complexes detected - - # Filter channel - eegchan = filtfilt(filter_coords, 1, data) + # NOTE: Here we previously checked for the last part of the window to be big enough. + while peak_idx[peak_count - 1] + peak_range > len(data): + peak_count = peak_count - 1 # reduce number of QRS complexes detected # build PCA matrix(heart-beat-epochs x window-length) pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time] # picking out heartbeat epochs for p in range(1, peak_count): - pcamat[p - 1, :] = eegchan[ - 0, peak_idx[p, 0] - peak_range : peak_idx[p, 0] + peak_range + 1 - ] + pcamat[p - 1, :] = data[peak_idx[p] - peak_range : peak_idx[p] + peak_range + 1] # detrending matrix(twice) pcamat = detrend( @@ -236,7 +216,7 @@ def _pca_obs( pca.fit(dpcamat) factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) - # define selected number of components using profile likelihood + # define selected number of components using profile likelihood ##################################### # Make template of the ECG artefact # @@ -249,95 +229,87 @@ def _pca_obs( ################ window_start_idx = [] window_end_idx = [] + post_idx_nextPeak = None + for p in range(peak_count): + # if the current peak doesn't have enough data in the + # start of the peak_range, skip fitting the artifact + if peak_idx[p] - peak_range < 0: + continue + # Deals with start portion of data if p == 0: pre_range = peak_range post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) if post_range > peak_range: post_range = peak_range - try: - post_idx_nextPeak = None - fitted_art, post_idx_nextPeak = fit_ecg_template( - data=data, - pca_template=pca_template, - a_peak_idx=peak_idx[p], - peak_range=peak_range, - pre_range=pre_range, - post_range=post_range, - mid_p=mid_p, - fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, - n_samples_fit=n_samples_fit, - ) - # Appending to list instead of using counter - window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - warn(f"Cannot fit first ECG epoch. Reason: {e}") + + fitted_art, post_idx_nextPeak = fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + # Appending to list instead of using counter + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) # Deals with last edge of data - elif p == peak_count-1: - logger.info("On last section - almost there!") - try: - pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) - post_range = peak_range - if pre_range > peak_range: - pre_range = peak_range - fitted_art, _ = fit_ecg_template( - data=data, - pca_template=pca_template, - a_peak_idx=peak_idx[p], - peak_range=peak_range, - pre_range=pre_range, - post_range=post_range, - mid_p=mid_p, - fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, - n_samples_fit=n_samples_fit, - ) - window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - warn(f"Cannot fit last ECG epoch. Reason: {e}") + elif p == peak_count - 1: + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = peak_range + if pre_range > peak_range: + pre_range = peak_range + fitted_art, _ = fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) # Deals with middle portion of data else: - try: - # ---------------- Processing of central data - -------------------- - # cycle through peak artifacts identified by peakplot - pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) - post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) - if pre_range >= peak_range: - pre_range = peak_range - if post_range > peak_range: - post_range = peak_range - - a_template = pca_template[ - mid_p - peak_range - 1 : mid_p + peak_range + 1, : - ] - fitted_art, post_idx_nextPeak = fit_ecg_template( - data=data, - pca_template=a_template, - a_peak_idx=peak_idx[p], - peak_range=peak_range, - pre_range=pre_range, - post_range=post_range, - mid_p=mid_p, - fitted_art=fitted_art, - post_idx_previous_peak=post_idx_nextPeak, - n_samples_fit=n_samples_fit, - ) - window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - warn(f"Cannot fit middle section of data. Reason: {e}") + # ---------------- Processing of central data - -------------------- + # cycle through peak artifacts identified by peakplot + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if pre_range >= peak_range: + pre_range = peak_range + if post_range > peak_range: + post_range = peak_range - # Actually subtract the artefact, return needs to be the same shape as input data - data = data.reshape(-1) - fitted_art = fitted_art.reshape(-1) + a_template = pca_template[ + mid_p - peak_range - 1 : mid_p + peak_range + 1, : + ] + fitted_art, post_idx_nextPeak = fit_ecg_template( + data=data, + pca_template=a_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + # Actually subtract the artefact, return needs to be the same shape as input data data -= fitted_art - data = data.T.reshape(-1) - return data diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py index 06d45062d6d..e2d07a8ce72 100644 --- a/mne/preprocessing/tests/test_pca_obs.py +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -1,44 +1,91 @@ -"""Test the ieeg projection functions.""" - # Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. -# TODO: migrate this structure to test out function +from pathlib import Path +import numpy as np +import pandas as pd import pytest from mne.io import read_raw_fif -from mne.preprocessing.pca_obs import pca_obs -from mne.datasets.testing import data_path, requires_testing_data +from mne.io.fiff.raw import Raw +from mne.preprocessing import apply_pca_obs -# TODO: Where are the test files we want to use located? -fname = data_path(download=False) / "eyetrack" / "test_eyelink.asc" +data_path = Path(__file__).parents[2] / "io" / "tests" / "data" +raw_fname = data_path / "test_raw.fif" -@requires_testing_data -@pytest.mark.parametrize( - # TODO: Are there any parameters we can cycle through to - # test multiple? Different fs, windows, highpass freqs, etc.? - # TODO: how do we determine qrs and filter_coords? What are these? - "fs, highpass_freq, qrs, filter_coords", - [ - (0.2, 1.0, 100, 200), - (0.1, 2.0, 100, 200), - ], -) -def test_heart_artifact_removal(fs, highpass_freq, qrs, filter_coords): + +@pytest.fixture() +def short_raw_data(): + """Create a short, picked raw instance.""" + return read_raw_fif(raw_fname, preload=True) + + +def test_heart_artifact_removal(short_raw_data: Raw): """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" - raw = read_raw_fif(fname) + # fake some random qrs events in the window of the raw data + # remove first and last samples and cast to integer for indexing + ecg_event_indices = np.linspace(0, short_raw_data.n_times, 20, dtype=int)[1:-1] + + # copy the original raw. heart artifact is removed in-place + orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True) - # Do something with fs and highpass as processing of the data? - ... + # perform heart artifact removal + apply_pca_obs( + raw=short_raw_data, picks=["eeg"], qrs_indices=ecg_event_indices, n_jobs=1 + ) - # call pca_obs algorithm - result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) + # compare processed df to original df + removed_heart_artifact_df: pd.DataFrame = short_raw_data.to_data_frame() - # assert results - assert result is not None - assert result.shape == (100, 100) - assert result.shape == raw.shape # is this a condition we can test? - assert result[0, 0] == 1.0 - ... \ No newline at end of file + # ensure all column names remain the same + pd.testing.assert_index_equal( + orig_df.columns, + removed_heart_artifact_df.columns, + ) + + # ensure every column starting with EEG has been altered + altered_cols = [c for c in orig_df.columns if c.startswith("EEG")] + for col in altered_cols: + with pytest.raises( + AssertionError + ): # make sure that error is raised when we check equal + pd.testing.assert_series_equal( + orig_df[col], + removed_heart_artifact_df[col], + ) + + # ensure every column not starting with EEG has not been altered + unaltered_cols = [c for c in orig_df.columns if not c.startswith("EEG")] + pd.testing.assert_frame_equal( + orig_df[unaltered_cols], + removed_heart_artifact_df[unaltered_cols], + ) + + +# test that various nonsensical inputs raise the proper errors +@pytest.mark.parametrize( + ("picks", "qrs", "error"), + [ + (["eeg"], np.array([[0, 1], [2, 3]]), "qrs_indices must be a 1d array"), + (["eeg"], [2, 3, 4], "qrs_indices must be an array"), + ( + ["eeg"], + np.array([None, "foo", 2]), + "qrs_indices must be an array of integers", + ), + ( + ["eeg"], + np.array([-1, 0, 3]), + "qrs_indices must be strictly positive integers", + ), + ([], np.array([1, 2, 3]), "picks must be a list of channel names"), + ], +) +def test_pca_obs_bad_input( + short_raw_data: Raw, picks: list[str], qrs: np.ndarray, error: str +): + """Test if bad input data raises the proper errors in the function sanity checks.""" + with pytest.raises(ValueError, match=error): + apply_pca_obs(raw=short_raw_data, picks=picks, qrs_indices=qrs)