diff --git a/doc/references.bib b/doc/references.bib index a129d2f46a2..e2578ed18f2 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -1335,6 +1335,16 @@ @inproceedings{NdiayeEtAl2016 year = {2016} } +@article{NiazyEtAl2005, + author = {Niazy, R. K. and Beckmann, C.F. and Iannetti, G.D. and Brady, J. M. and Smith, S. M.}, + title = {Removal of FMRI environment artifacts from EEG data using optimal basis sets}, + journal = {NeuroImage}, + year = {2005}, + volume = {28}, + pages = {720-737}, + doi = {10.1016/j.neuroimage.2005.06.067} +} + @article{NicholsHolmes2002, author = {Nichols, Thomas E. and Holmes, Andrew P.}, doi = {10.1002/hbm.1058}, diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py new file mode 100755 index 00000000000..f5e3ad3757f --- /dev/null +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -0,0 +1,169 @@ +""" +.. _ex-pcaobs: + +============================================================================================== +Principal Component Analysis - Optimal Basis Sets (PCA-OBS) for removal of cardiac artefact +============================================================================================== + +This script shows an example of how to use an adaptation of PCA-OBS +:footcite:`NiazyEtAl2005`. PCA-OBS was originally designed to remove +the ballistocardiographic artefact in simultaneous EEG-fMRI. Here, it +has been adapted to remove the delay between the detected R-peak and the +ballistocardiographic artefact such that the algorithm can be applied to +remove the cardiac artefact in EEG (electroencephalography) and ESG +(electrospinography) data. We will illustrate how it works by applying the +algorithm to ESG data, where the effect of removal is most pronounced. + +See: https://www.biorxiv.org/content/10.1101/2024.09.05.611423v1 +for more details on the dataset and application for ESG data. 
+ +""" + +# Authors: Emma Bailey , +# Steinn Hauser Magnusson +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from matplotlib import pyplot as plt +import mne +from mne.preprocessing import find_ecg_events, fix_stim_artifact +from mne.io import read_raw_eeglab +from scipy.signal import firls +import numpy as np +from mne import Epochs, events_from_annotations, concatenate_raws + +############################################################################### +# Download sample subject data from OpenNeuro if you haven't already +# This will download simultaneous EEG and ESG data from a single participant after +# median nerve stimulation of the left wrist +# Set the target directory to your desired location +import openneuro as on +import glob + +# add the path where you want the OpenNeuro data downloaded. Files total around 8 GB +# target_dir = "/home/steinnhm/personal/mne-data" +target_dir = '/data/pt_02569/test_data' + +file_list = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +if file_list: + print('Data is already downloaded') +else: + on.download(dataset='ds004388', target_dir=target_dir, include='sub-001/*median*_eeg*') + +############################################################################### +# Define the esg channels (arranged in two patches over the neck and lower back) +# Also include the ECG channel for artefact correction +esg_chans = ["S35", "S24", "S36", "Iz", "S17", "S15", "S32", "S22", "S19", "S26", "S28", + "S9", "S13", "S11", "S7", "SC1", "S4", "S18", "S8", "S31", "SC6", "S12", + "S16", "S5", "S30", "S20", "S34", "S21", "S25", "L1", "S29", "S14", "S33", + "S3", "L4", "S6", "S23", 'ECG'] + +# Sampling rate +fs = 1000 + +# Interpolation window for ESG data to remove stimulation artefact +tstart_esg = -0.007 +tmax_esg = 0.007 + +# Define timing of heartbeat epochs +iv_baseline = [-400 / 1000, -300 / 1000] +iv_epoch = [-400 / 1000, 600 / 1000] + 
+############################################################################### +# Read in each of the four blocks and concatenate the raw structures after performing +# some minimal preprocessing including removing the stimulation artefact, downsampling +# and filtering +block_files = glob.glob(target_dir + '/sub-001/eeg/*median*.set') +block_files = sorted(block_files) + +for count, block_file in enumerate(block_files): + raw = read_raw_eeglab(block_file, eog=(), preload=True, uint16_codec=None, verbose=None) + + # Isolate the ESG channels only + raw.pick(esg_chans) + + # Find trigger timings to remove the stimulation artefact + events, event_dict = events_from_annotations(raw) + trigger_name = 'Median - Stimulation' + + fix_stim_artifact(raw, events=events, event_id=event_dict[trigger_name], tmin=tstart_esg, tmax=tmax_esg, mode='linear', + stim_channel=None) + + # Downsample the data + raw.resample(fs) + + # Append blocks of the same condition + if count == 0: + raw_concat = raw + else: + concatenate_raws([raw_concat, raw]) + +############################################################################### +# Find ECG events and add to the raw structure as event annotations +ecg_events, ch_ecg, average_pulse = find_ecg_events(raw_concat, ch_name="ECG") +ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) # Samples only + +qrs_event_time = [x / fs for x in ecg_event_samples.reshape(-1)] # Divide by sampling rate to make times +duration = np.repeat(0.0, len(ecg_event_samples)) +description = ['qrs'] * len(ecg_event_samples) + +raw_concat.annotations.append(qrs_event_time, duration, description, ch_names=[esg_chans]*len(qrs_event_time)) + +############################################################################### +# Create filter coefficients +a = [0, 0, 1, 1] +f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter +ord = round(3 * fs / 0.5) +fwts = firls(ord + 1, f, a) + 
+############################################################################### +# Create evoked response about the detected R-peaks before cardiac artefact correction +# Apply PCA-OBS to remove the cardiac artefact +# Create evoked response about the detected R-peaks after cardiac artefact correction +events, event_ids = events_from_annotations(raw_concat) +event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} +epochs = Epochs( + raw_concat, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_before = epochs.average() + +# Apply function - modifies the data in place +mne.preprocessing.apply_pca_obs( + raw_concat, + picks=esg_chans, + n_jobs=5, + qrs=ecg_event_samples, + filter_coords=fwts +) + +epochs = Epochs( + raw_concat, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_after = epochs.average() + +############################################################################### +# Comparison image +fig, axes = plt.subplots(1, 1) +axes.plot(evoked_before.times, evoked_before.get_data().T, color="black") +axes.plot(evoked_after.times, evoked_after.get_data().T, color="green") +axes.set_ylim([-0.0005, 0.001]) +axes.set_ylabel('Amplitude (V)') +axes.set_xlabel('Time (s)') +axes.set_title("Before (black) vs. After (green)") +plt.tight_layout() +plt.show() + +# %% +# References +# ---------- +# .. footbibliography:: diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi index 54f1c825c13..d58c6e77d24 100644 --- a/mne/preprocessing/__init__.pyi +++ b/mne/preprocessing/__init__.pyi @@ -44,6 +44,7 @@ __all__ = [ "realign_raw", "regress_artifact", "write_fine_calibration", + "apply_pca_obs", ] from . 
import eyetracking, ieeg, nirs from ._annotate_amplitude import annotate_amplitude @@ -89,3 +90,4 @@ from .realign import realign_raw from .ssp import compute_proj_ecg, compute_proj_eog from .stim import fix_stim_artifact from .xdawn import Xdawn +from .pca_obs import apply_pca_obs \ No newline at end of file diff --git a/mne/preprocessing/pca_obs.py b/mne/preprocessing/pca_obs.py new file mode 100755 index 00000000000..fe877f7beb9 --- /dev/null +++ b/mne/preprocessing/pca_obs.py @@ -0,0 +1,343 @@ +"""Principal Component Analysis Optimal Basis Sets (PCA-OBS).""" + +# Authors: Emma Bailey , +# Steinn Hauser Magnusson +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import math +from typing import Optional + +import numpy as np +from scipy.signal import detrend, filtfilt +from sklearn.decomposition import PCA +from scipy.interpolate import PchipInterpolator as pchip +from scipy.signal import detrend + +from mne.io.fiff.raw import Raw +from mne.utils import logger, warn + + +# TODO: check arguments passed in, raise errors, tests + +def fit_ecg_template( + data: np.ndarray, + pca_template: np.ndarray, + a_peak_idx: int, + peak_range: int, + pre_range: int, + post_range: int, + mid_p: float, + fitted_art: np.ndarray, + post_idx_previous_peak: Optional[int], + n_samples_fit: int, +) -> tuple[np.ndarray, int]: + """ + Fits the heartbeat artefact found in the data. + Returns the fitted artefact and the index of the next peak. + + Parameters + ---------- + data (ndarray): Data from the raw signal (n_channels, n_times) + pca_template (ndarray): Mean heartbeat and first N (default 4) + principal components of the heartbeat matrix + a_peak_idx (int): Sample index of current R-peak + peak_range (int): Half the median RR-interval + pre_range (int): Number of samples to fit before the R-peak + post_range (int): Number of samples to fit after the R-peak + mid_p (float): Sample index marking middle of the median RR interval + in the signal. 
Used to extract relevant part of PCA_template. + fitted_art (ndarray): The computed heartbeat artefact computed to + remove from the data + post_idx_previous_peak (optional int): Sample index of previous R-peak + n_samples_fit (int): Sample fit for interpolation between fitted artifact windows. + Helps reduce sharp edges at the end of fitted heartbeat events. + + Returns + ------- + tuple[np.ndarray, int]: the fitted artifact and the next peak index + """ + + # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previous_peak + # Then nextpeak is returned at the end and the process repeats + # select window of template + template = pca_template[mid_p - peak_range - 1: mid_p + peak_range + 1, :] + + # select window of data and detrend it + slice = data[0, a_peak_idx[0] - peak_range : a_peak_idx[0] + peak_range + 1] + detrended_data = detrend(slice.reshape(-1), type="constant") + + # maps data on template and then maps it again back to the sensor space + least_square = np.linalg.lstsq(template, detrended_data, rcond=None) + pad_fit = np.dot(template, least_square[0]) + + # fit artifact + fitted_art[0, a_peak_idx[0] - pre_range - 1: a_peak_idx[0] + post_range] = pad_fit[ + mid_p - pre_range - 1: mid_p + post_range + ].T + + # if last peak, return + if post_idx_previous_peak is None: + return fitted_art, a_peak_idx[0] + post_range + + # interpolate time between peaks + intpol_window = np.ceil( + [post_idx_previous_peak, a_peak_idx[0] - pre_range] + ).astype(int) # interpolation window + + if intpol_window[0] < intpol_window[1]: + # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data + + # You have x_fit which is two slices on either side of the interpolation window endpoints + # You have y_fit which is the y vals corresponding to x values above + # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate + # You have y_interpol which is values from pchip at the time points specified in 
x_interpol + x_interpol = np.arange( + intpol_window[0], intpol_window[1] + 1, 1 + ) # points to be interpolated in pt - the gap between the endpoints of the window + x_fit = np.concatenate( + [ + np.arange( + intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1 + ), + np.arange( + intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1 + ), + ] + ) # Entire range of x values in this step (taking some number of samples before and after the window) + y_fit = fitted_art[0, x_fit] + y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation + + # Then make fitted artefact in the desired range equal to the completed fit above + fitted_art[0, post_idx_previous_peak: a_peak_idx[0] - pre_range + 1] = ( + y_interpol + ) + + return fitted_art, a_peak_idx[0] + post_range + + +def apply_pca_obs( + raw: Raw, + picks: list[str], + qrs: np.ndarray, + filter_coords: np.ndarray, + n_components: int = 4, + n_jobs: Optional[int] = None, +) -> None: + """ + Main convenience function for applying the PCA-OBS algorithm + to certain picks of a Raw object. Updates the Raw object in-place. + Makes sanity checks for all inputs. + + Parameters + ---------- + raw: Raw + The raw data to process + picks: list[str] + Channels in the Raw object to remove the heart artefact from + qrs: ndarray, shape (1, n_peaks) + Array of sample indices of the R-peaks detected in the ECG channel. 
+ filter_coords: ndarray (N, ) + The numerator coefficient vector of the filter passed to scipy.signal.filtfilt + n_components: int, default 4 + Number of PCA components to use to form the OBS + n_jobs: int, default None + Number of jobs to perform the PCA-OBS processing in parallel + """ + + if not qrs: + raise ValueError("qrs must not be empty") + + if not filter_coords: + raise ValueError("filter_coords must not be empty") + + raw.apply_function( + _pca_obs, + picks=picks, + n_jobs=n_jobs, + # args sent to PCA_OBS + qrs=qrs, + filter_coords=filter_coords, + n_components=n_components, + ) + +def _pca_obs( + data: np.ndarray, + qrs: np.ndarray, + filter_coords: np.ndarray, + n_components: int, +) -> np.ndarray: + """ + Algorithm to perform the PCA OBS (Principal Component Analysis, Optimal Basis Sets) + algorithm to remove the heart artefact from EEG data (shape [n_channels, n_times]) + """ + + # set to baseline + data = data.reshape(-1, 1) + data = data.T + data = data - np.mean(data, axis=1) + + # Allocate memory + fitted_art = np.zeros(data.shape) + peakplot = np.zeros(data.shape) + + # Extract QRS events + for idx in qrs[0]: + if idx < len(peakplot[0, :]): + peakplot[0, idx] = 1 # logical indexed locations of qrs events + + peak_idx = np.nonzero(peakplot)[1] # Selecting indices along columns + peak_idx = peak_idx.reshape(-1, 1) + peak_count = len(peak_idx) + + ################################################################## + # Preparatory work - reserving memory, configure sizes, de-trend # + ################################################################## + logger.info("Pulse artifact subtraction in progress...Please wait!") + + # define peak range based on RR + RR = np.diff(peak_idx[:, 0]) + mRR = np.median(RR) + peak_range = round(mRR / 2) # Rounds to an integer + mid_p = peak_range + 1 + n_samples_fit = round( + peak_range / 8 + ) # sample fit for interpolation between fitted artifact windows + + # make sure array is long enough for PArange (if not 
cut off last ECG peak) + while peak_idx[peak_count - 1, 0] + peak_range > len(data[0]): + peak_count = peak_count - 1 # reduce number of QRS complexes detected + + # Filter channel + eegchan = filtfilt(filter_coords, 1, data) + + # build PCA matrix(heart-beat-epochs x window-length) + pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time] + # picking out heartbeat epochs + for p in range(1, peak_count): + pcamat[p - 1, :] = eegchan[ + 0, peak_idx[p, 0] - peak_range : peak_idx[p, 0] + peak_range + 1 + ] + + # detrending matrix(twice) + pcamat = detrend( + pcamat, type="constant", axis=1 + ) # [epoch x time] - detrended along the epoch + mean_effect: np.ndarray = np.mean( + pcamat, axis=0 + ) # [1 x time], contains the mean over all epochs + dpcamat = detrend(pcamat, type="constant", axis=1) # [time x epoch] + + ############################ + # Perform PCA with sklearn # + ############################ + # run PCA, perform singular value decomposition (SVD) + pca = PCA(svd_solver="full") + pca.fit(dpcamat) + factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_) + + # define selected number of components using profile likelihood + + ##################################### + # Make template of the ECG artefact # + ##################################### + mean_effect = mean_effect.reshape(-1, 1) + pca_template = np.c_[mean_effect, factor_loadings[:, :n_components]] + + ################ + # Data Fitting # + ################ + window_start_idx = [] + window_end_idx = [] + for p in range(peak_count): + # Deals with start portion of data + if p == 0: + pre_range = peak_range + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if post_range > peak_range: + post_range = peak_range + try: + post_idx_nextPeak = None + fitted_art, post_idx_nextPeak = fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + 
fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + # Appending to list instead of using counter + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + except Exception as e: + warn(f"Cannot fit first ECG epoch. Reason: {e}") + + # Deals with last edge of data + elif p == peak_count-1: + logger.info("On last section - almost there!") + try: + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = peak_range + if pre_range > peak_range: + pre_range = peak_range + fitted_art, _ = fit_ecg_template( + data=data, + pca_template=pca_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + except Exception as e: + warn(f"Cannot fit last ECG epoch. 
Reason: {e}") + + # Deals with middle portion of data + else: + try: + # ---------------- Processing of central data - -------------------- + # cycle through peak artifacts identified by peakplot + pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) + post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) + if pre_range >= peak_range: + pre_range = peak_range + if post_range > peak_range: + post_range = peak_range + + a_template = pca_template[ + mid_p - peak_range - 1 : mid_p + peak_range + 1, : + ] + fitted_art, post_idx_nextPeak = fit_ecg_template( + data=data, + pca_template=a_template, + a_peak_idx=peak_idx[p], + peak_range=peak_range, + pre_range=pre_range, + post_range=post_range, + mid_p=mid_p, + fitted_art=fitted_art, + post_idx_previous_peak=post_idx_nextPeak, + n_samples_fit=n_samples_fit, + ) + window_start_idx.append(peak_idx[p] - peak_range) + window_end_idx.append(peak_idx[p] + peak_range) + except Exception as e: + warn(f"Cannot fit middle section of data. Reason: {e}") + + # Actually subtract the artefact, return needs to be the same shape as input data + data = data.reshape(-1) + fitted_art = fitted_art.reshape(-1) + + data -= fitted_art + data = data.T.reshape(-1) + + return data diff --git a/mne/preprocessing/pca_obs/PCA_OBS.py b/mne/preprocessing/pca_obs/PCA_OBS.py deleted file mode 100755 index aee06165d11..00000000000 --- a/mne/preprocessing/pca_obs/PCA_OBS.py +++ /dev/null @@ -1,185 +0,0 @@ -import numpy as np -# import mne -from scipy.signal import filtfilt, detrend -import matplotlib.pyplot as plt -from sklearn.decomposition import PCA -from fit_ecgTemplate import fit_ecgTemplate -import math -import h5py - - -def PCA_OBS(data, **kwargs): - - # Declare class to hold pca information - class PCAInfo(): - def __init__(self): - pass - - # Instantiate class - pca_info = PCAInfo() - - # Check all necessary arguments sent in - required_kws = ["qrs", "filter_coords", "sr"] - assert all([kw in kwargs.keys() for kw in required_kws]), 
"Error. Some KWs not passed into PCA_OBS." - - # Extract all kwargs - qrs = kwargs['qrs'] - filter_coords = kwargs['filter_coords'] - sr = kwargs['sr'] - - fs = sr - - # set to baseline - data = data.reshape(-1, 1) - data = data.T - data = data - np.mean(data, axis=1) - - # Allocate memory - fitted_art = np.zeros(data.shape) - peakplot = np.zeros(data.shape) - - # Extract QRS events - for idx in qrs[0]: - if idx < len(peakplot[0, :]): - peakplot[0, idx] = 1 # logical indexed locations of qrs events - - peak_idx = np.nonzero(peakplot)[1] # Selecting indices along columns - peak_idx = peak_idx.reshape(-1, 1) - peak_count = len(peak_idx) - - ################################################################ - # Preparatory work - reserving memory, configure sizes, de-trend - ################################################################ - print('Pulse artifact subtraction in progress...Please wait!') - - # define peak range based on RR - RR = np.diff(peak_idx[:, 0]) - mRR = np.median(RR) - peak_range = round(mRR/2) # Rounds to an integer - midP = peak_range + 1 - baseline_range = [0, round(peak_range/8)] - n_samples_fit = round(peak_range/8) # sample fit for interpolation between fitted artifact windows - - # make sure array is long enough for PArange (if not cut off last ECG peak) - pa = peak_count # Number of QRS complexes detected - while peak_idx[pa-1, 0] + peak_range > len(data[0]): - pa = pa - 1 - steps = 1 * pa - peak_count = pa - - # Filter channel - eegchan = filtfilt(filter_coords, 1, data) - - # build PCA matrix(heart-beat-epochs x window-length) - pcamat = np.zeros((peak_count - 1, 2*peak_range+1)) # [epoch x time] - # picking out heartbeat epochs - for p in range(1, peak_count): - pcamat[p-1, :] = eegchan[0, peak_idx[p, 0] - peak_range: peak_idx[p, 0] + peak_range+1] - - # detrending matrix(twice) - pcamat = detrend(pcamat, type='constant', axis=1) # [epoch x time] - detrended along the epoch - mean_effect = np.mean(pcamat, axis=0) # [1 x time], contains 
the mean over all epochs - std_effect = np.std(pcamat, axis=0) # want mean and std of each column - dpcamat = detrend(pcamat, type='constant', axis=1) # [time x epoch] - - ################################################################### - # Perform PCA with sklearn - ################################################################### - # run PCA(performs SVD(singular value decomposition)) - pca = PCA(svd_solver="full") - pca.fit(dpcamat) - eigen_vectors = pca.components_ - eigen_values = pca.explained_variance_ - factor_loadings = pca.components_.T*np.sqrt(pca.explained_variance_) - pca_info.eigen_vectors = eigen_vectors - pca_info.factor_loadings = factor_loadings - pca_info.eigen_values = eigen_values - pca_info.expl_var = pca.explained_variance_ratio_ - - # define selected number of components using profile likelihood - pca_info.nComponents = 4 - pca_info.meanEffect = mean_effect.T - nComponents = pca_info.nComponents - - ####################################################################### - # Make template of the ECG artefact - ####################################################################### - mean_effect = mean_effect.reshape(-1, 1) - pca_template = np.c_[mean_effect, factor_loadings[:, 0:nComponents]] - - ################################################################################### - # Data Fitting - ################################################################################### - window_start_idx = [] - window_end_idx = [] - for p in range(0, peak_count): - # Deals with start portion of data - if p == 0: - pre_range = peak_range - post_range = math.floor((peak_idx[p + 1] - peak_idx[p])/2) - if post_range > peak_range: - post_range = peak_range - try: - post_idx_nextPeak = [] - fitted_art, post_idx_nextPeak = fit_ecgTemplate(data, pca_template, peak_idx[p], peak_range, - pre_range, post_range, baseline_range, midP, - fitted_art, post_idx_nextPeak, n_samples_fit) - # Appending to list instead of using counter - 
window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - print(f'Cannot fit first ECG epoch. Reason: {e}') - - # Deals with last edge of data - elif p == peak_count: - print('On last section - almost there!') - try: - pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) - post_range = peak_range - if pre_range > peak_range: - pre_range = peak_range - fitted_art, _ = fit_ecgTemplate(data, pca_template, peak_idx(p), peak_range, pre_range, post_range, - baseline_range, midP, fitted_art, post_idx_nextPeak, n_samples_fit) - window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - print(f'Cannot fit last ECG epoch. Reason: {e}') - - # Deals with middle portion of data - else: - try: - # ---------------- Processing of central data - -------------------- - # cycle through peak artifacts identified by peakplot - pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2) - post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2) - if pre_range >= peak_range: - pre_range = peak_range - if post_range > peak_range: - post_range = peak_range - - aTemplate = pca_template[midP - peak_range-1:midP + peak_range+1, :] - fitted_art, post_idx_nextPeak = fit_ecgTemplate(data, aTemplate, peak_idx[p], peak_range, pre_range, - post_range, baseline_range, midP, fitted_art, - post_idx_nextPeak, n_samples_fit) - window_start_idx.append(peak_idx[p] - peak_range) - window_end_idx.append(peak_idx[p] + peak_range) - except Exception as e: - print(f"Cannot fit middle section of data. 
Reason: {e}") - - # Actually subtract the artefact, return needs to be the same shape as input data - # One sample shift purely due to the fact the r-peaks are currently detected in MATLAB - data = data.reshape(-1) - fitted_art = fitted_art.reshape(-1) - - # One sample shift for my actual data (introduced using matlab r timings) - # data_ = np.zeros(len(data)) - # data_[0] = data[0] - # data_[1:] = data[1:] - fitted_art[:-1] - # data = data_ - - # Original code is this: - data -= fitted_art - data = data.T.reshape(-1) - - # Can only return data - return data diff --git a/mne/preprocessing/pca_obs/fit_ecgTemplate.py b/mne/preprocessing/pca_obs/fit_ecgTemplate.py deleted file mode 100755 index 11043b0c56d..00000000000 --- a/mne/preprocessing/pca_obs/fit_ecgTemplate.py +++ /dev/null @@ -1,71 +0,0 @@ -import numpy as np -from scipy.signal import detrend -from scipy.interpolate import PchipInterpolator as pchip -import h5py - - -def fit_ecgTemplate(data, pca_template, aPeak_idx, peak_range, pre_range, post_range, baseline_range, midP, fitted_art, post_idx_previousPeak, n_samples_fit): - # Declare class to hold ecg fit information - class fitECG(): - def __init__(self): - pass - - # Instantiate class - fitecg = fitECG() - - # post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak - # Then nextpeak is returned at the end and the process repeats - # select window of template - template = pca_template[midP - peak_range-1: midP + peak_range+1, :] - - # select window of data and detrend it - slice = data[0, aPeak_idx[0] - peak_range:aPeak_idx[0] + peak_range+1] - detrended_data = detrend(slice.reshape(-1), type='constant') - - # maps data on template and then maps it again back to the sensor space - least_square = np.linalg.lstsq(template, detrended_data, rcond=None) - pad_fit = np.dot(template, least_square[0]) - - # fit artifact, I already loop through externally channel to channel - fitted_art[0, aPeak_idx[0] - pre_range-1: aPeak_idx[0] + post_range] 
= pad_fit[midP - pre_range-1: midP + post_range].T - - fitecg.fitted_art = fitted_art - fitecg.template = template - fitecg.detrended_data = detrended_data - fitecg.pad_fit = pad_fit - fitecg.aPeak_idx = aPeak_idx - fitecg.midP = midP - fitecg.peak_range = peak_range - fitecg.data = data - - post_idx_nextPeak = [aPeak_idx[0] + post_range] - - # Check it's not empty - if len(post_idx_previousPeak) != 0: - # interpolate time between peaks - intpol_window = np.ceil([post_idx_previousPeak[0], aPeak_idx[0] - pre_range]).astype('int') # interpolation window - fitecg.intpol_window = intpol_window - - if intpol_window[0] < intpol_window[1]: - # Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data - - # You have x_fit which is two slices on either side of the interpolation window endpoints - # You have y_fit which is the y vals corresponding to x values above - # You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate - # You have y_interpol which is values from pchip at the time points specified in x_interpol - x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1) # points to be interpolated in pt - the gap between the endpoints of the window - x_fit = np.concatenate([np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1), - np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1)]) # Entire range of x values in this step (taking some number of samples before and after the window) - y_fit = fitted_art[0, x_fit] - y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation - - # Then make fitted artefact in the desired range equal to the completed fit above - fitted_art[0, post_idx_previousPeak[0]: aPeak_idx[0] - pre_range+1] = y_interpol - - fitecg.x_fit = x_fit - fitecg.y_fit = y_fit - fitecg.x_interpol = x_interpol - fitecg.y_interpol = y_interpol - fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop - - return fitted_art, 
post_idx_nextPeak diff --git a/mne/preprocessing/pca_obs/pchip_interpolation.py b/mne/preprocessing/pca_obs/pchip_interpolation.py deleted file mode 100755 index e336f672363..00000000000 --- a/mne/preprocessing/pca_obs/pchip_interpolation.py +++ /dev/null @@ -1,42 +0,0 @@ -# Function to interpolate based on PCHIP rather than MNE inbuilt linear option - -# import mne -import numpy as np -from scipy.interpolate import PchipInterpolator as pchip -import matplotlib.pyplot as plt - - -def PCHIP_interpolation(data, **kwargs): - # Check all necessary arguments sent in - required_kws = ["trigger_indices", "interpol_window_sec", "fs"] - assert all([kw in kwargs.keys() for kw in required_kws]), "Error. Some KWs not passed into PCA_OBS." - - # Extract all kwargs - more elegant ways to do this - fs = kwargs['fs'] - interpol_window_sec = kwargs['interpol_window_sec'] - trigger_indices = kwargs['trigger_indices'] - - # Convert intpol window to msec then convert to samples - pre_window = round((interpol_window_sec[0]*1000) * fs / 1000) # in samples - post_window = round((interpol_window_sec[1]*1000) * fs / 1000) # in samples - intpol_window = np.ceil([pre_window, post_window]).astype(int) # interpolation window - - n_samples_fit = 5 # number of samples before and after cut used for interpolation fit - - x_fit_raw = np.concatenate([np.arange(intpol_window[0]-n_samples_fit-1, intpol_window[0], 1), - np.arange(intpol_window[1]+1, intpol_window[1]+n_samples_fit+2, 1)]) - x_interpol_raw = np.arange(intpol_window[0], intpol_window[1]+1, 1) # points to be interpolated; in pt - - for ii in np.arange(0, len(trigger_indices)): # loop through all stimulation events - x_fit = trigger_indices[ii] + x_fit_raw # fit point latencies for this event - x_interpol = trigger_indices[ii] + x_interpol_raw # latencies for to-be-interpolated data points - - # Data is just a string of values - y_fit = data[x_fit] # y values to be fitted - y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform 
interpolation - data[x_interpol] = y_interpol # replace in data - - if np.mod(ii, 100) == 0: # talk to the operator every 100th trial - print(f'stimulation event {ii} \n') - - return data diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py b/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py deleted file mode 100644 index 42e58e99463..00000000000 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_cortical_mnedata.py +++ /dev/null @@ -1,77 +0,0 @@ -# Checking algorithm implementation with EEG data form MNE sample datasets -# Following this tutorial data: https://mne.tools/stable/auto_tutorials/preprocessing/50_artifact_correction_ssp.html#what-is-ssp - -from mne.preprocessing import ( - find_ecg_events, - create_ecg_epochs, -) -from mne.io import read_raw_fif -from mne.datasets.sample import data_path -from PCA_OBS import * -from mne import Epochs -import os -import numpy as np -from scipy.signal import firls -import matplotlib.pyplot as plt - -if __name__ == '__main__': - sample_data_folder = data_path(path='/data/pt_02569/mne_test_data/') - sample_data_raw_file = os.path.join( - sample_data_folder, "MEG", "sample", "sample_audvis_raw.fif" - ) - # here we crop and resample just for speed - raw = read_raw_fif(sample_data_raw_file, preload=True) - - # Find ECG events - no ECG channel in data, uses synthetic - ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw) - # Extract just sample timings of ecg events - ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) - # print(ecg_events) - - # Create filter coefficients - fs = raw.info['sfreq'] - a = [0, 0, 1, 1] - f = [0, 0.4 / (fs / 2), 0.9 / (fs / 2), 1] # 0.9 Hz highpass filter - # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3 * fs / 0.5) - fwts = firls(ord + 1, f, a) - - # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] - iv_epoch = [-400 / 1000, 600 / 1000] - - # run PCA_OBS - # Algorithm is extremely 
sensitive to accurate R-peak timings, won't work as well with the artificial ECG - # channel estimation as we have here - PCA_OBS_kwargs = dict( - qrs=ecg_event_samples, filter_coords=fwts, sr=fs - ) - - epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) - evoked_before = epochs.average() - - # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks='eeg', **PCA_OBS_kwargs, n_jobs=10) - epochs = Epochs(raw, ecg_events, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) - evoked_after = epochs.average() - - # Comparison image - fig, axes = plt.subplots(2, 1) - axes[0].plot(evoked_before.times, evoked_before.get_data().T) - axes[0].set_ylim([-1e-5, 3e-5]) - axes[0].set_title('Before PCA-OBS') - axes[1].plot(evoked_after.times, evoked_after.get_data().T) - axes[1].set_ylim([-1e-5, 3e-5]) - axes[1].set_title('After PCA-OBS') - plt.tight_layout() - - # Comparison image - fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') - axes.set_ylim([-1e-5, 3e-5]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') - axes.set_title('Before (black) versus after (green)') - plt.tight_layout() - plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py deleted file mode 100755 index 7634096c909..00000000000 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_impreciserpeak.py +++ /dev/null @@ -1,80 +0,0 @@ -# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component -# Analysis, Optimal Basis Sets) - -import os -from scipy.io import loadmat -from scipy.signal import firls -from PCA_OBS import * -from mne.io import read_raw_fif -from mne import events_from_annotations, Epochs -from mne.preprocessing import find_ecg_events - - -if __name__ == 
'__main__': - # Incredibly slow without parallelization - # Set variables - subject_id = f'sub-001' - cond_name = 'median' - sampling_rate = 1000 - esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', - 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', - 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', - 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', - 'S23'] - # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] - iv_epoch = [-400 / 1000, 600 / 1000] - - # Setting paths - input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" - fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" - raw = read_raw_fif(input_path + fname + '.fif', preload=True) - - # Find ECG events - no ECG channel in data, uses synthetic - ecg_events, ch_ecg, average_pulse, = find_ecg_events(raw, ch_name='ECG') - # Extract just sample timings of ecg events - ecg_event_samples = np.asarray([[ecg_event[0] for ecg_event in ecg_events]]) - - # Create filter coefficients - fs = sampling_rate - a = [0, 0, 1, 1] - f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter - # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3*fs/0.5) - fwts = firls(ord+1, f, a) - - # run PCA_OBS - PCA_OBS_kwargs = dict( - qrs=ecg_event_samples, filter_coords=fwts, sr=sampling_rate - ) - - events, event_ids = events_from_annotations(raw) - event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) - evoked_before = epochs.average() - - # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) - evoked_after = epochs.average() - 
- # Comparison image - fig, axes = plt.subplots(2, 1) - axes[0].plot(evoked_before.times, evoked_before.get_data().T) - axes[0].set_ylim([-0.0005, 0.001]) - axes[0].set_title('Before PCA-OBS') - axes[1].plot(evoked_after.times, evoked_after.get_data().T) - axes[1].set_ylim([-0.0005, 0.001]) - axes[1].set_title('After PCA-OBS') - plt.tight_layout() - - # Comparison image - fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') - axes.set_ylim([-0.0005, 0.001]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') - axes.set_title('Before (black) versus after (green)') - plt.tight_layout() - plt.show() diff --git a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py b/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py deleted file mode 100755 index 1b559ef9f22..00000000000 --- a/mne/preprocessing/pca_obs/rm_heart_artefact_spinal_preciserpeak.py +++ /dev/null @@ -1,80 +0,0 @@ -# Calls PCA_OBS which in turn calls fit_ecgTemplate to remove the heart artefact via PCA_OBS (Principal Component -# Analysis, Optimal Basis Sets) - -import os -from scipy.io import loadmat -from scipy.signal import firls -from PCA_OBS import * -from mne.io import read_raw_fif -from mne import events_from_annotations, Epochs - - -if __name__ == '__main__': - # Incredibly slow without parallelization - # Set variables - subject_id = f'sub-001' - cond_name = 'median' - sampling_rate = 1000 - esg_chans = ['S35', 'S24', 'S36', 'Iz', 'S17', 'S15', 'S32', 'S22', - 'S19', 'S26', 'S28', 'S9', 'S13', 'S11', 'S7', 'SC1', 'S4', 'S18', - 'S8', 'S31', 'SC6', 'S12', 'S16', 'S5', 'S30', 'S20', 'S34', 'AC', - 'S21', 'S25', 'L1', 'S29', 'S14', 'S33', 'S3', 'AL', 'L4', 'S6', - 'S23'] - # For heartbeat epochs - iv_baseline = [-300 / 1000, -200 / 1000] - iv_epoch = [-400 / 1000, 600 / 1000] - - # Setting paths - input_path = "/data/pt_02569/tmp_data/prepared_py/"+subject_id+"/esg/prepro/" - input_path_m = 
"/data/pt_02569/tmp_data/prepared/" + subject_id + "/esg/prepro/" - fname = f"noStimart_sr{sampling_rate}_{cond_name}_withqrs_pchip" - raw = read_raw_fif(input_path + fname + '.fif', preload=True) - - # Read .mat file with QRS events - fname_m = f"raw_{sampling_rate}_spinal_{cond_name}" - matdata = loadmat(input_path_m+fname_m+'.mat') - QRSevents_m = matdata['QRSevents'] - - # Create filter coefficients - fs = sampling_rate - a = [0, 0, 1, 1] - f = [0, 0.4/(fs/2), 0.9/(fs / 2), 1] # 0.9 Hz highpass filter - # f = [0 0.4 / (fs / 2) 0.5 / (fs / 2) 1] # 0.5 Hz highpass filter - ord = round(3*fs/0.5) - fwts = firls(ord+1, f, a) - - # run PCA_OBS - PCA_OBS_kwargs = dict( - qrs=QRSevents_m, filter_coords=fwts, sr=sampling_rate - ) - - events, event_ids = events_from_annotations(raw) - event_id_dict = {key: value for key, value in event_ids.items() if key == 'qrs'} - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) - evoked_before = epochs.average() - - # Apply function should modifies the data in raw in place - raw.apply_function(PCA_OBS, picks=esg_chans, **PCA_OBS_kwargs, n_jobs=len(esg_chans)) - epochs = Epochs(raw, events, event_id=event_id_dict, tmin=iv_epoch[0], tmax=iv_epoch[1], - baseline=tuple(iv_baseline)) - evoked_after = epochs.average() - - # Comparison image - fig, axes = plt.subplots(2, 1) - axes[0].plot(evoked_before.times, evoked_before.get_data().T) - axes[0].set_ylim([-0.0005, 0.001]) - axes[0].set_title('Before PCA-OBS') - axes[1].plot(evoked_after.times, evoked_after.get_data().T) - axes[1].set_ylim([-0.0005, 0.001]) - axes[1].set_title('After PCA-OBS') - plt.tight_layout() - - # Comparison image - fig, axes = plt.subplots(1, 1) - axes.plot(evoked_before.times, evoked_before.get_data().T, color='black') - axes.set_ylim([-0.0005, 0.001]) - axes.plot(evoked_after.times, evoked_after.get_data().T, color='green') - axes.set_title('Before (black) versus after (green)') - 
plt.tight_layout() - plt.show() diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py new file mode 100644 index 00000000000..06d45062d6d --- /dev/null +++ b/mne/preprocessing/tests/test_pca_obs.py @@ -0,0 +1,44 @@ +"""Test the PCA-OBS heart artifact removal function.""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +# TODO: migrate this structure to test out function + +import pytest + +from mne.io import read_raw_fif +from mne.preprocessing.pca_obs import pca_obs +from mne.datasets.testing import data_path, requires_testing_data + +# TODO: Where are the test files we want to use located? +fname = data_path(download=False) / "eyetrack" / "test_eyelink.asc" + +@requires_testing_data +@pytest.mark.parametrize( + # TODO: Are there any parameters we can cycle through to + # test multiple? Different fs, windows, highpass freqs, etc.? + # TODO: how do we determine qrs and filter_coords? What are these? + "fs, highpass_freq, qrs, filter_coords", + [ + (0.2, 1.0, 100, 200), + (0.1, 2.0, 100, 200), + ], +) +def test_heart_artifact_removal(fs, highpass_freq, qrs, filter_coords): + """Test PCA-OBS analysis and heart artifact removal of ECG datasets.""" + raw = read_raw_fif(fname) + + # Do something with fs and highpass as processing of the data? + ... + + # call pca_obs algorithm + result = pca_obs(raw, qrs=qrs, filter_coords=filter_coords) + + # assert results + assert result is not None + assert result.shape == (100, 100) + assert result.shape == raw.shape # is this a condition we can test? + assert result[0, 0] == 1.0 + ... \ No newline at end of file